You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2021/01/18 23:39:51 UTC
[airflow] branch master updated: Fix race conditions in task
callback invocations (#10917)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new f1d4f54 Fix race conditions in task callback invocations (#10917)
f1d4f54 is described below
commit f1d4f54b3479cd7549ce79efadd25cc6859dd420
Author: QP Hou <qp...@scribd.com>
AuthorDate: Mon Jan 18 15:39:41 2021 -0800
Fix race conditions in task callback invocations (#10917)
This race condition resulted in task success and failure callbacks being
called more than once. Here is the order of events that could lead to
this issue:
* task started running within process 2
* (process 1) local_task_job checked the task return code, which was still None
* (process 2) task exited with failure state, task state updated as failed in DB
* (process 2) task failure callback invoked through taskinstance.handle_failure method
* (process 1) local_task_job heartbeat noticed the task state had been set to
failed, mistook it for the state being updated externally, and also invoked the
task failure callback
To avoid this race condition, we need to make sure task callbacks are
only invoked within a single process: the caller of the raw task run
(e.g. local_task_job), rather than the task process itself.
---
airflow/cli/cli_parser.py | 2 +
airflow/cli/commands/task_command.py | 5 +-
airflow/executors/debug_executor.py | 4 +
airflow/jobs/backfill_job.py | 2 +-
airflow/jobs/local_task_job.py | 36 +++--
airflow/jobs/scheduler_job.py | 6 +-
airflow/models/dag.py | 2 +-
airflow/models/taskinstance.py | 139 +++++++++++++-----
airflow/task/task_runner/base_task_runner.py | 35 +++--
airflow/task/task_runner/standard_task_runner.py | 3 +-
airflow/utils/process_utils.py | 7 +-
tests/core/test_core.py | 4 +-
tests/jobs/test_local_task_job.py | 156 ++++++++++++++++-----
tests/jobs/test_scheduler_job.py | 5 +-
tests/models/test_taskinstance.py | 6 +-
.../apache/hive/transfers/test_mysql_to_hive.py | 67 +++++----
16 files changed, 343 insertions(+), 136 deletions(-)
diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py
index 35949ea..8c52c86 100644
--- a/airflow/cli/cli_parser.py
+++ b/airflow/cli/cli_parser.py
@@ -405,6 +405,7 @@ ARG_SHIP_DAG = Arg(
("--ship-dag",), help="Pickles (serializes) the DAG and ships it to the worker", action="store_true"
)
ARG_PICKLE = Arg(("-p", "--pickle"), help="Serialized pickle object of the entire dag (used internally)")
+ARG_ERROR_FILE = Arg(("--error-file",), help="File to store task failure error")
ARG_JOB_ID = Arg(("-j", "--job-id"), help=argparse.SUPPRESS)
ARG_CFG_PATH = Arg(("--cfg-path",), help="Path to config file to use instead of airflow.cfg")
ARG_MIGRATION_TIMEOUT = Arg(
@@ -962,6 +963,7 @@ TASKS_COMMANDS = (
ARG_PICKLE,
ARG_JOB_ID,
ARG_INTERACTIVE,
+ ARG_ERROR_FILE,
ARG_SHUT_DOWN_LOGGING,
),
),
diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index c3bce77..b89ecfa 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -47,7 +47,7 @@ from airflow.utils.net import get_hostname
from airflow.utils.session import create_session
-def _run_task_by_selected_method(args, dag, ti):
+def _run_task_by_selected_method(args, dag: DAG, ti: TaskInstance) -> None:
"""
Runs the task in one of 3 modes
@@ -132,7 +132,7 @@ RAW_TASK_UNSUPPORTED_OPTION = [
]
-def _run_raw_task(args, ti):
+def _run_raw_task(args, ti: TaskInstance) -> None:
"""Runs the main task handling code"""
unsupported_options = [o for o in RAW_TASK_UNSUPPORTED_OPTION if getattr(args, o)]
@@ -149,6 +149,7 @@ def _run_raw_task(args, ti):
mark_success=args.mark_success,
job_id=args.job_id,
pool=args.pool,
+ error_file=args.error_file,
)
diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py
index 580dc65..25aace0 100644
--- a/airflow/executors/debug_executor.py
+++ b/airflow/executors/debug_executor.py
@@ -66,6 +66,7 @@ class DebugExecutor(BaseExecutor):
self.log.info("Executor is terminated! Stopping %s to %s", ti.key, State.FAILED)
ti.set_state(State.FAILED)
self.change_state(ti.key, State.FAILED)
+ ti._run_finished_callback() # pylint: disable=protected-access
continue
task_succeeded = self._run_task(ti)
@@ -77,9 +78,12 @@ class DebugExecutor(BaseExecutor):
params = self.tasks_params.pop(ti.key, {})
ti._run_raw_task(job_id=ti.job_id, **params) # pylint: disable=protected-access
self.change_state(key, State.SUCCESS)
+ ti._run_finished_callback() # pylint: disable=protected-access
return True
except Exception as e: # pylint: disable=broad-except
+ ti.set_state(State.FAILED)
self.change_state(key, State.FAILED)
+ ti._run_finished_callback() # pylint: disable=protected-access
self.log.exception("Failed to execute task: %s.", str(e))
return False
diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py
index da64b21..0d3d057 100644
--- a/airflow/jobs/backfill_job.py
+++ b/airflow/jobs/backfill_job.py
@@ -280,7 +280,7 @@ class BackfillJob(BaseJob):
"killed externally? Info: {}".format(ti, state, ti.state, info)
)
self.log.error(msg)
- ti.handle_failure(msg)
+ ti.handle_failure_with_callback(error=msg)
@provide_session
def _get_dag_run(self, run_date: datetime, dag: DAG, session: Session = None):
diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index f4d4ef0..d68bfc7 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -98,7 +98,14 @@ class LocalTaskJob(BaseJob):
heartbeat_time_limit = conf.getint('scheduler', 'scheduler_zombie_task_threshold')
- while True:
+ # task callback invocation happens either here or in
+ # self.heartbeat() instead of taskinstance._run_raw_task to
+ # avoid race conditions
+ #
+ # When self.terminating is set to True by heartbeat_callback, this
+ # loop should not be restarted. Otherwise self.handle_task_exit
+ # will be invoked and we will end up with duplicated callbacks
+ while not self.terminating:
# Monitor the task to see if it's done. Wait in a syscall
# (`os.wait`) for as long as possible so we notice the
# subprocess finishing as quick as we can
@@ -115,7 +122,7 @@ class LocalTaskJob(BaseJob):
return_code = self.task_runner.return_code(timeout=max_wait_time)
if return_code is not None:
- self.log.info("Task exited with return code %s", return_code)
+ self.handle_task_exit(return_code)
return
self.heartbeat()
@@ -134,6 +141,17 @@ class LocalTaskJob(BaseJob):
finally:
self.on_kill()
+ def handle_task_exit(self, return_code: int) -> None:
+ """Handle case where self.task_runner exits by itself"""
+ self.log.info("Task exited with return code %s", return_code)
+ self.task_instance.refresh_from_db()
+ # task exited by itself, so we need to check for error file
+ # incase it failed due to runtime exception/error
+ error = None
+ if self.task_instance.state != State.SUCCESS:
+ error = self.task_runner.deserialize_run_error()
+ self.task_instance._run_finished_callback(error=error) # pylint: disable=protected-access
+
def on_kill(self):
self.task_runner.terminate()
self.task_runner.on_finish()
@@ -169,11 +187,13 @@ class LocalTaskJob(BaseJob):
self.log.warning(
"State of this instance has been externally set to %s. " "Terminating instance.", ti.state
)
- if ti.state == State.FAILED and ti.task.on_failure_callback:
- context = ti.get_template_context()
- ti.task.on_failure_callback(context)
- if ti.state == State.SUCCESS and ti.task.on_success_callback:
- context = ti.get_template_context()
- ti.task.on_success_callback(context)
self.task_runner.terminate()
+ if ti.state == State.SUCCESS:
+ error = None
+ else:
+ # if ti.state is not set by taskinstance.handle_failure, then
+ # error file will not be populated and it must be updated by
+ # external source suck as web UI
+ error = self.task_runner.deserialize_run_error() or "task marked as failed externally"
+ ti._run_finished_callback(error=error) # pylint: disable=protected-access
self.terminating = True
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 3d58077..82b7561 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -590,7 +590,7 @@ class DagFileProcessor(LoggingMixin):
ti.state = simple_ti.state
ti.test_mode = self.UNIT_TEST_MODE
if request.is_failure_callback:
- ti.handle_failure(request.msg, ti.test_mode, ti.get_template_context())
+ ti.handle_failure_with_callback(error=request.msg, test_mode=ti.test_mode)
self.log.info('Executed failure callback for %s in state %s', ti, ti.state)
@provide_session
@@ -1732,8 +1732,8 @@ class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes
pools = models.Pool.slots_stats(session=session)
for pool_name, slot_stats in pools.items():
Stats.gauge(f'pool.open_slots.{pool_name}', slot_stats["open"])
- Stats.gauge(f'pool.queued_slots.{pool_name}', slot_stats[State.QUEUED])
- Stats.gauge(f'pool.running_slots.{pool_name}', slot_stats[State.RUNNING])
+ Stats.gauge(f'pool.queued_slots.{pool_name}', slot_stats[State.QUEUED]) # type: ignore
+ Stats.gauge(f'pool.running_slots.{pool_name}', slot_stats[State.RUNNING]) # type: ignore
@provide_session
def heartbeat_callback(self, session: Session = None) -> None:
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 1d0105d..d9096ff 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -584,7 +584,7 @@ class DAG(LoggingMixin):
next_run_date = None
if not date_last_automated_dagrun:
# First run
- task_start_dates = [t.start_date for t in self.tasks]
+ task_start_dates = [t.start_date for t in self.tasks if t.start_date]
if task_start_dates:
next_run_date = self.normalize_schedule(min(task_start_dates))
self.log.debug("Next run date based on tasks %s", next_run_date)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 203b5db..6b6f906 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -21,10 +21,12 @@ import hashlib
import logging
import math
import os
+import pickle
import signal
import warnings
from datetime import datetime, timedelta
-from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
+from tempfile import NamedTemporaryFile
+from typing import IO, Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
from urllib.parse import quote
import dill
@@ -105,6 +107,29 @@ def set_current_context(context: Context):
)
+def load_error_file(fd: IO[bytes]) -> Optional[Union[str, Exception]]:
+ """Load and return error from error file"""
+ fd.seek(0, os.SEEK_SET)
+ data = fd.read()
+ if not data:
+ return None
+ try:
+ return pickle.loads(data)
+ except Exception: # pylint: disable=broad-except
+ return "Failed to load task run error"
+
+
+def set_error_file(error_file: str, error: Union[str, Exception]) -> None:
+ """Write error into error file by path"""
+ with open(error_file, "wb") as fd:
+ try:
+ pickle.dump(error, fd)
+ except Exception: # pylint: disable=broad-except
+ # local class objects cannot be pickled, so we fallback
+ # to store the string representation instead
+ pickle.dump(str(error), fd)
+
+
def clear_task_instances(
tis,
session,
@@ -1053,6 +1078,7 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
test_mode: bool = False,
job_id: Optional[str] = None,
pool: Optional[str] = None,
+ error_file: Optional[str] = None,
session=None,
) -> None:
"""
@@ -1111,7 +1137,7 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
return
except AirflowFailException as e:
self.refresh_from_db()
- self.handle_failure(e, test_mode, context, force_fail=True)
+ self.handle_failure(e, test_mode, force_fail=True, error_file=error_file)
raise
except AirflowException as e:
self.refresh_from_db()
@@ -1120,16 +1146,14 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
if self.state in {State.SUCCESS, State.FAILED}:
return
else:
- self.handle_failure(e, test_mode, context)
+ self.handle_failure(e, test_mode, error_file=error_file)
raise
except (Exception, KeyboardInterrupt) as e:
- self.handle_failure(e, test_mode, context)
+ self.handle_failure(e, test_mode, error_file=error_file)
raise
finally:
Stats.incr(f'ti.finish.{task.dag_id}.{task.task_id}.{self.state}')
- self._run_success_callback(context, task)
-
# Recording SUCCESS
self.end_date = timezone.utcnow()
self.log.info(
@@ -1275,16 +1299,6 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
# Raise exception for sensing state
raise AirflowSmartSensorException("Task successfully registered in smart sensor.")
- def _run_success_callback(self, context, task):
- """Functions that need to be run if Task is successful"""
- # Success callback
- try:
- if task.on_success_callback:
- task.on_success_callback(context)
- except Exception as exc: # pylint: disable=broad-except
- self.log.error("Failed when executing success callback")
- self.log.exception(exc)
-
def _execute_task(self, context, task_copy):
"""Executes Task (optionally with a Timeout) and pushes Xcom results"""
# If a timeout is specified for the task, make it fail
@@ -1303,7 +1317,7 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
self.xcom_push(key=XCOM_RETURN_KEY, value=result)
return result
- def _run_execute_callback(self, context, task):
+ def _run_execute_callback(self, context: Context, task):
"""Functions that need to be run before a Task is executed"""
try:
if task.on_execute_callback:
@@ -1312,6 +1326,31 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
self.log.error("Failed when executing execute callback")
self.log.exception(exc)
+ def _run_finished_callback(self, error: Optional[Union[str, Exception]] = None) -> None:
+ """
+ Call callback defined for finished state change.
+
+ NOTE: Only invoke this function from caller of self._run_raw_task or
+ self.run
+ """
+ if self.state == State.FAILED:
+ task = self.task
+ if task.on_failure_callback is not None:
+ context = self.get_template_context()
+ context["exception"] = error
+ task.on_failure_callback(context)
+ elif self.state == State.SUCCESS:
+ task = self.task
+ if task.on_success_callback is not None:
+ context = self.get_template_context()
+ task.on_success_callback(context)
+ elif self.state == State.UP_FOR_RETRY:
+ task = self.task
+ if task.on_retry_callback is not None:
+ context = self.get_template_context()
+ context["exception"] = error
+ task.on_retry_callback(context)
+
@provide_session
def run( # pylint: disable=too-many-arguments
self,
@@ -1339,10 +1378,23 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
pool=pool,
session=session,
)
- if res:
+ if not res:
+ return
+
+ try:
+ error_fd = NamedTemporaryFile(delete=True)
self._run_raw_task(
- mark_success=mark_success, test_mode=test_mode, job_id=job_id, pool=pool, session=session
+ mark_success=mark_success,
+ test_mode=test_mode,
+ job_id=job_id,
+ pool=pool,
+ error_file=error_fd.name,
+ session=session,
)
+ finally:
+ error = None if self.state == State.SUCCESS else load_error_file(error_fd)
+ error_fd.close()
+ self._run_finished_callback(error=error)
def dry_run(self):
"""Only Renders Templates for the TI"""
@@ -1386,14 +1438,25 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
self.log.info('Rescheduling task, marking task as UP_FOR_RESCHEDULE')
@provide_session
- def handle_failure(self, error, test_mode=None, context=None, force_fail=False, session=None):
+ def handle_failure(
+ self,
+ error: Union[str, Exception],
+ test_mode: Optional[bool] = None,
+ force_fail: bool = False,
+ error_file: Optional[str] = None,
+ session=None,
+ ) -> None:
"""Handle Failure for the TaskInstance"""
if test_mode is None:
test_mode = self.test_mode
- if context is None:
- context = self.get_template_context()
- self.log.exception(error)
+ if error:
+ self.log.exception(error)
+ # external monitoring process provides pickle file so _run_raw_task
+ # can send its runtime errors for access by failure callback
+ if error_file:
+ set_error_file(error_file, error)
+
task = self.task
self.end_date = timezone.utcnow()
self.set_duration()
@@ -1405,11 +1468,12 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
# Log failure duration
session.add(TaskFail(task, self.execution_date, self.start_date, self.end_date))
- if context is not None:
- context['exception'] = error
+ # Set state correctly and figure out how to log it and decide whether
+ # to email
- # Set state correctly and figure out how to log it,
- # what callback to call if any, and how to decide whether to email
+ # Note, callback invocation needs to be handled by caller of
+ # _run_raw_task to avoid race conditions which could lead to duplicate
+ # invocations or miss invocation.
# Since this function is called only when the TaskInstance state is running,
# try_number contains the current try_number (not the next). We
@@ -1423,12 +1487,10 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
else:
log_message = "Marking task as FAILED."
email_for_state = task.email_on_failure
- callback = task.on_failure_callback
else:
self.state = State.UP_FOR_RETRY
log_message = "Marking task as UP_FOR_RETRY."
email_for_state = task.email_on_retry
- callback = task.on_retry_callback
self.log.info(
'%s dag_id=%s, task_id=%s, execution_date=%s, start_date=%s, end_date=%s',
@@ -1446,18 +1508,21 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
self.log.error('Failed to send email to: %s', task.email)
self.log.exception(exec2)
- # Handling callbacks pessimistically
- if callback:
- try:
- callback(context)
- except Exception as exec3: # pylint: disable=broad-except
- self.log.error("Failed at executing callback")
- self.log.exception(exec3)
-
if not test_mode:
session.merge(self)
session.commit()
+ @provide_session
+ def handle_failure_with_callback(
+ self,
+ error: Union[str, Exception],
+ test_mode: Optional[bool] = None,
+ force_fail: bool = False,
+ session=None,
+ ) -> None:
+ self.handle_failure(error=error, test_mode=test_mode, force_fail=force_fail, session=session)
+ self._run_finished_callback(error=error)
+
def is_eligible_to_retry(self):
"""Is task instance is eligible for retry"""
return self.task.retries and self.try_number <= self.max_tries
diff --git a/airflow/task/task_runner/base_task_runner.py b/airflow/task/task_runner/base_task_runner.py
index 743685e..81235ea 100644
--- a/airflow/task/task_runner/base_task_runner.py
+++ b/airflow/task/task_runner/base_task_runner.py
@@ -20,9 +20,12 @@ import getpass
import os
import subprocess
import threading
+from tempfile import NamedTemporaryFile
+from typing import Optional, Union
from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException
+from airflow.models.taskinstance import load_error_file
from airflow.utils.configuration import tmp_configuration_copy
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
@@ -81,17 +84,26 @@ class BaseTaskRunner(LoggingMixin):
# - the runner can read/execute those values as it needs
cfg_path = tmp_configuration_copy(chmod=0o600)
+ self._error_file = NamedTemporaryFile(delete=True)
self._cfg_path = cfg_path
- self._command = popen_prepend + self._task_instance.command_as_list(
- raw=True,
- pickle_id=local_task_job.pickle_id,
- mark_success=local_task_job.mark_success,
- job_id=local_task_job.id,
- pool=local_task_job.pool,
- cfg_path=cfg_path,
+ self._command = (
+ popen_prepend
+ + self._task_instance.command_as_list(
+ raw=True,
+ pickle_id=local_task_job.pickle_id,
+ mark_success=local_task_job.mark_success,
+ job_id=local_task_job.id,
+ pool=local_task_job.pool,
+ cfg_path=cfg_path,
+ )
+ + ["--error-file", self._error_file.name]
)
self.process = None
+ def deserialize_run_error(self) -> Optional[Union[str, Exception]]:
+ """Return task runtime error if its written to provided error file."""
+ return load_error_file(self._error_file)
+
def _read_task_logs(self, stream):
while True:
line = stream.readline()
@@ -144,7 +156,7 @@ class BaseTaskRunner(LoggingMixin):
"""Start running the task instance in a subprocess."""
raise NotImplementedError()
- def return_code(self):
+ def return_code(self) -> Optional[int]:
"""
:return: The return code associated with running the task instance or
None if the task is not yet done.
@@ -152,14 +164,15 @@ class BaseTaskRunner(LoggingMixin):
"""
raise NotImplementedError()
- def terminate(self):
- """Kill the running task instance."""
+ def terminate(self) -> None:
+ """Force kill the running task instance."""
raise NotImplementedError()
- def on_finish(self):
+ def on_finish(self) -> None:
"""A callback that should be called when this is done running."""
if self._cfg_path and os.path.isfile(self._cfg_path):
if self.run_as_user:
subprocess.call(['sudo', 'rm', self._cfg_path], close_fds=True)
else:
os.remove(self._cfg_path)
+ self._error_file.close()
diff --git a/airflow/task/task_runner/standard_task_runner.py b/airflow/task/task_runner/standard_task_runner.py
index fee9b0d..505b225 100644
--- a/airflow/task/task_runner/standard_task_runner.py
+++ b/airflow/task/task_runner/standard_task_runner.py
@@ -18,6 +18,7 @@
"""Standard task runner"""
import logging
import os
+from typing import Optional
import psutil
from setproctitle import setproctitle # pylint: disable=no-name-in-module
@@ -91,7 +92,7 @@ class StandardTaskRunner(BaseTaskRunner):
logging.shutdown()
os._exit(return_code) # pylint: disable=protected-access
- def return_code(self, timeout=0):
+ def return_code(self, timeout: int = 0) -> Optional[int]:
# We call this multiple times, but we can only wait on the process once
if self._rc is not None or not self.process:
return self._rc
diff --git a/airflow/utils/process_utils.py b/airflow/utils/process_utils.py
index 1ee1ac6..38607bd 100644
--- a/airflow/utils/process_utils.py
+++ b/airflow/utils/process_utils.py
@@ -44,7 +44,12 @@ log = logging.getLogger(__name__)
DEFAULT_TIME_TO_WAIT_AFTER_SIGTERM = conf.getint('core', 'KILLED_TASK_CLEANUP_TIME')
-def reap_process_group(pgid, logger, sig=signal.SIGTERM, timeout=DEFAULT_TIME_TO_WAIT_AFTER_SIGTERM):
+def reap_process_group(
+ pgid: int,
+ logger,
+ sig: 'signal.Signals' = signal.SIGTERM,
+ timeout: int = DEFAULT_TIME_TO_WAIT_AFTER_SIGTERM,
+) -> Dict[int, int]:
"""
Tries really hard to terminate all processes in the group (including grandchildren). Will send
sig (SIGTERM) to the process group of pid. If any process is alive after timeout
diff --git a/tests/core/test_core.py b/tests/core/test_core.py
index 5073aa4..fae2c73 100644
--- a/tests/core/test_core.py
+++ b/tests/core/test_core.py
@@ -190,9 +190,9 @@ class TestCore(unittest.TestCase):
# Annoying workaround for nonlocal not existing in python 2
data = {'called': False}
- def check_failure(context, test_case=self):
+ def check_failure(context, test_case=self): # pylint: disable=unused-argument
data['called'] = True
- error = context.get('exception')
+ error = context.get("exception")
test_case.assertIsInstance(error, AirflowException)
op = BashOperator(
diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index fdd0163..537a242 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -21,20 +21,22 @@ import os
import time
import unittest
import uuid
+from multiprocessing import Lock, Value
from unittest import mock
from unittest.mock import patch
import pytest
from airflow import settings
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowFailException
from airflow.executors.sequential_executor import SequentialExecutor
from airflow.jobs.local_task_job import LocalTaskJob
from airflow.models.dag import DAG
from airflow.models.dagbag import DagBag
from airflow.models.taskinstance import TaskInstance
-from airflow.operators.dummy import DummyOperator
+from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python import PythonOperator
+from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
from airflow.utils import timezone
from airflow.utils.net import get_hostname
from airflow.utils.session import create_session
@@ -242,8 +244,6 @@ class TestLocalTaskJob(unittest.TestCase):
ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti_run.refresh_from_db()
job1 = LocalTaskJob(task_instance=ti_run, executor=SequentialExecutor())
- from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
-
with patch.object(StandardTaskRunner, 'start', return_value=None) as mock_method:
job1.run()
mock_method.assert_not_called()
@@ -286,8 +286,6 @@ class TestLocalTaskJob(unittest.TestCase):
return return_codes.pop(0)
time_start = time.time()
- from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
-
with patch.object(StandardTaskRunner, 'start', return_value=None) as mock_start:
with patch.object(StandardTaskRunner, 'return_code') as mock_ret_code:
mock_ret_code.side_effect = multi_return_code
@@ -311,14 +309,18 @@ class TestLocalTaskJob(unittest.TestCase):
Test that ensures that mark_failure in the UI fails
the task, and executes on_failure_callback
"""
- data = {'called': False}
+ # use shared memory value so we can properly track value change even if
+ # it's been updated across processes.
+ failure_callback_called = Value('i', 0)
+ task_terminated_externally = Value('i', 1)
def check_failure(context):
+ with failure_callback_called.get_lock():
+ failure_callback_called.value += 1
assert context['dag_run'].dag_id == 'test_mark_failure'
- data['called'] = True
+ assert context['exception'] == "task marked as failed externally"
def task_function(ti):
- print("python_callable run in pid %s", os.getpid())
with create_session() as session:
assert State.RUNNING == ti.state
ti.log.info("Marking TI as failed 'externally'")
@@ -326,9 +328,10 @@ class TestLocalTaskJob(unittest.TestCase):
session.merge(ti)
session.commit()
- time.sleep(60)
+ time.sleep(10)
# This should not happen -- the state change should be noticed and the task should get killed
- data['reached_end_of_sleep'] = True
+ with task_terminated_externally.get_lock():
+ task_terminated_externally.value = 0
with DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE) as dag:
task = PythonOperator(
@@ -337,16 +340,15 @@ class TestLocalTaskJob(unittest.TestCase):
on_failure_callback=check_failure,
)
- session = settings.Session()
-
dag.clear()
- dag.create_dagrun(
- run_id="test",
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session,
- )
+ with create_session() as session:
+ dag.create_dagrun(
+ run_id="test",
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ session=session,
+ )
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.refresh_from_db()
@@ -358,24 +360,106 @@ class TestLocalTaskJob(unittest.TestCase):
ti.refresh_from_db()
assert ti.state == State.FAILED
- assert data['called']
- assert 'reached_end_of_sleep' not in data, 'Task should not have been allowed to run to completion'
+ assert failure_callback_called.value == 1
+ assert task_terminated_externally.value == 1
+
+ @patch('airflow.utils.process_utils.subprocess.check_call')
+ @patch.object(StandardTaskRunner, 'return_code')
+ def test_failure_callback_only_called_once(self, mock_return_code, _check_call):
+ """
+ Test that ensures that when a task exits with failure by itself,
+ failure callback is only called once
+ """
+ # use shared memory value so we can properly track value change even if
+ # it's been updated across processes.
+ failure_callback_called = Value('i', 0)
+ callback_count_lock = Lock()
+
+ def failure_callback(context):
+ with callback_count_lock:
+ failure_callback_called.value += 1
+ assert context['dag_run'].dag_id == 'test_failure_callback_race'
+ assert isinstance(context['exception'], AirflowFailException)
+
+ def task_function(ti):
+ raise AirflowFailException()
+
+ dag = DAG(dag_id='test_failure_callback_race', start_date=DEFAULT_DATE)
+ task = PythonOperator(
+ task_id='test_exit_on_failure',
+ python_callable=task_function,
+ on_failure_callback=failure_callback,
+ dag=dag,
+ )
+
+ dag.clear()
+ with create_session() as session:
+ dag.create_dagrun(
+ run_id="test",
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ session=session,
+ )
+ ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
+ ti.refresh_from_db()
+
+ job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
+
+ # Simulate race condition where job1 heartbeat ran right after task
+ # state got set to failed by ti.handle_failure but before task process
+ # fully exits. See _execute loop in airflow/jobs/local_task_job.py.
+ # In this case, we have:
+ # * task_runner.return_code() is None
+ # * ti.state == State.Failed
+ #
+ # We also need to set return_code to a valid int after job1.terminating
+ # is set to True so _execute loop won't loop forever.
+ def dummy_return_code(*args, **kwargs):
+ return None if not job1.terminating else -9
+
+ mock_return_code.side_effect = dummy_return_code
+
+ with timeout(10):
+ # This should be _much_ shorter to run.
+ # If you change this limit, make the timeout in the callbable above bigger
+ job1.run()
+
+ ti.refresh_from_db()
+ assert ti.state == State.FAILED # task exits with failure state
+ assert failure_callback_called.value == 1
- @pytest.mark.quarantined
def test_mark_success_on_success_callback(self):
"""
Test that ensures that where a task is marked suceess in the UI
on_success_callback gets executed
"""
- data = {'called': False}
+ # use shared memory value so we can properly track value change even if
+ # it's been updated across processes.
+ success_callback_called = Value('i', 0)
+ task_terminated_externally = Value('i', 1)
+ shared_mem_lock = Lock()
def success_callback(context):
+ with shared_mem_lock:
+ success_callback_called.value += 1
assert context['dag_run'].dag_id == 'test_mark_success'
- data['called'] = True
dag = DAG(dag_id='test_mark_success', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
- task = DummyOperator(task_id='test_state_succeeded1', dag=dag, on_success_callback=success_callback)
+ def task_function(ti):
+ # pylint: disable=unused-argument
+ time.sleep(60)
+ # This should not happen -- the state change should be noticed and the task should get killed
+ with shared_mem_lock:
+ task_terminated_externally.value = 0
+
+ task = PythonOperator(
+ task_id='test_state_succeeded1',
+ python_callable=task_function,
+ on_success_callback=success_callback,
+ dag=dag,
+ )
session = settings.Session()
@@ -390,25 +474,25 @@ class TestLocalTaskJob(unittest.TestCase):
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.refresh_from_db()
job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
- from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
-
job1.task_runner = StandardTaskRunner(job1)
+
+ settings.engine.dispose()
process = multiprocessing.Process(target=job1.run)
process.start()
- ti.refresh_from_db()
- for _ in range(0, 50):
+
+ for _ in range(0, 25):
+ ti.refresh_from_db()
if ti.state == State.RUNNING:
break
- time.sleep(0.1)
- ti.refresh_from_db()
- assert State.RUNNING == ti.state
+ time.sleep(0.2)
+ assert ti.state == State.RUNNING
ti.state = State.SUCCESS
session.merge(ti)
session.commit()
- job1.heartbeat_callback(session=None)
- assert data['called']
process.join(timeout=10)
+ assert success_callback_called.value == 1
+ assert task_terminated_externally.value == 1
assert not process.is_alive()
@@ -436,5 +520,5 @@ class TestLocalTaskJobPerformance:
mock_get_task_runner.return_value.return_code.side_effects = return_codes
job = LocalTaskJob(task_instance=ti, executor=MockExecutor())
- with assert_queries_count(12):
+ with assert_queries_count(13):
job.run()
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index c5340e8..ef6c52b 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -642,7 +642,7 @@ class TestDagFileProcessor(unittest.TestCase):
num_scheduled = scheduler._schedule_dag_run(dr3, {dr1.execution_date}, session)
assert num_scheduled == 0
- @patch.object(TaskInstance, 'handle_failure')
+ @patch.object(TaskInstance, 'handle_failure_with_callback')
def test_execute_on_failure_callbacks(self, mock_ti_handle_failure):
dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False)
dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
@@ -663,7 +663,8 @@ class TestDagFileProcessor(unittest.TestCase):
]
dag_file_processor.execute_callbacks(dagbag, requests)
mock_ti_handle_failure.assert_called_once_with(
- "Message", conf.getboolean('core', 'unit_test_mode'), mock.ANY
+ error="Message",
+ test_mode=conf.getboolean('core', 'unit_test_mode'),
)
def test_process_file_should_failure_callback(self):
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 012c547..cd99b02 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -1389,8 +1389,9 @@ class TestTaskInstance(unittest.TestCase):
callback_wrapper.wrap_task_instance(ti)
ti._run_raw_task()
+ ti._run_finished_callback()
assert callback_wrapper.callback_ran
- assert callback_wrapper.task_state_in_callback == State.RUNNING
+ assert callback_wrapper.task_state_in_callback == State.SUCCESS
ti.refresh_from_db()
assert ti.state == State.SUCCESS
@@ -1618,6 +1619,7 @@ class TestTaskInstance(unittest.TestCase):
ti1 = TI(task=task1, execution_date=start_date)
ti1.state = State.FAILED
ti1.handle_failure("test failure handling")
+ ti1._run_finished_callback()
context_arg_1 = mock_on_failure_1.call_args[0][0]
assert context_arg_1 and "task_instance" in context_arg_1
@@ -1635,6 +1637,7 @@ class TestTaskInstance(unittest.TestCase):
ti2 = TI(task=task2, execution_date=start_date)
ti2.state = State.FAILED
ti2.handle_failure("test retry handling")
+ ti2._run_finished_callback()
mock_on_failure_2.assert_not_called()
@@ -1654,6 +1657,7 @@ class TestTaskInstance(unittest.TestCase):
ti3 = TI(task=task3, execution_date=start_date)
ti3.state = State.FAILED
ti3.handle_failure("test force_fail handling", force_fail=True)
+ ti3._run_finished_callback()
context_arg_3 = mock_on_failure_3.call_args[0][0]
assert context_arg_3 and "task_instance" in context_arg_3
diff --git a/tests/providers/apache/hive/transfers/test_mysql_to_hive.py b/tests/providers/apache/hive/transfers/test_mysql_to_hive.py
index b921c74..c6f7736 100644
--- a/tests/providers/apache/hive/transfers/test_mysql_to_hive.py
+++ b/tests/providers/apache/hive/transfers/test_mysql_to_hive.py
@@ -18,6 +18,7 @@
import unittest
from collections import OrderedDict
+from os import path
from unittest import mock
import pytest
@@ -35,6 +36,27 @@ DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
TEST_DAG_ID = 'unit_test_dag'
+class HiveopTempFile:
+ """
+ Compares equal to any temp file path in the format of "/tmp/airflow_hiveop_t_78lpye/tmpour2_kig".
+ """
+
+ def __eq__(self, other: str) -> bool:
+ (head, tail) = path.split(other)
+ (head, tail) = path.split(head)
+ return tail.startswith("airflow_hiveop_")
+
+
+class HiveopTempDir:
+ """
+ Compares equal to any temp dir path in the format of "/tmp/airflow_hiveop_t_78lpye".
+ """
+
+ def __eq__(self, other: str) -> bool:
+ (_, tail) = path.split(other)
+ return tail.startswith("airflow_hiveop_")
+
+
@pytest.mark.backend("mysql")
class TestTransfer(unittest.TestCase):
def setUp(self):
@@ -126,13 +148,10 @@ class TestTransfer(unittest.TestCase):
with MySqlHook().get_conn() as cur:
cur.execute("DROP TABLE IF EXISTS baby_names CASCADE;")
- @mock.patch('tempfile.tempdir', '/tmp/')
- @mock.patch('tempfile._RandomNameSequence.__next__')
@mock.patch('subprocess.Popen')
- def test_mysql_to_hive(self, mock_popen, mock_temp_dir):
+ def test_mysql_to_hive(self, mock_popen):
mock_subprocess = MockSubProcess()
mock_popen.return_value = mock_subprocess
- mock_temp_dir.return_value = "test_mysql_to_hive"
with mock.patch.dict('os.environ', self.env_vars):
sql = "SELECT * FROM baby_names LIMIT 1000;"
@@ -170,24 +189,21 @@ class TestTransfer(unittest.TestCase):
'-hiveconf',
'tez.queue.name=airflow',
'-f',
- '/tmp/airflow_hiveop_test_mysql_to_hive/tmptest_mysql_to_hive',
+ HiveopTempFile(),
]
mock_popen.assert_called_with(
hive_cmd,
stdout=mock_subprocess.PIPE,
stderr=mock_subprocess.STDOUT,
- cwd="/tmp/airflow_hiveop_test_mysql_to_hive",
+ cwd=HiveopTempDir(),
close_fds=True,
)
- @mock.patch('tempfile.tempdir', '/tmp/')
- @mock.patch('tempfile._RandomNameSequence.__next__')
@mock.patch('subprocess.Popen')
- def test_mysql_to_hive_partition(self, mock_popen, mock_temp_dir):
+ def test_mysql_to_hive_partition(self, mock_popen):
mock_subprocess = MockSubProcess()
mock_popen.return_value = mock_subprocess
- mock_temp_dir.return_value = "test_mysql_to_hive_partition"
with mock.patch.dict('os.environ', self.env_vars):
sql = "SELECT * FROM baby_names LIMIT 1000;"
@@ -227,24 +243,21 @@ class TestTransfer(unittest.TestCase):
'-hiveconf',
'tez.queue.name=airflow',
'-f',
- '/tmp/airflow_hiveop_test_mysql_to_hive_partition/tmptest_mysql_to_hive_partition',
+ HiveopTempFile(),
]
mock_popen.assert_called_with(
hive_cmd,
stdout=mock_subprocess.PIPE,
stderr=mock_subprocess.STDOUT,
- cwd="/tmp/airflow_hiveop_test_mysql_to_hive_partition",
+ cwd=HiveopTempDir(),
close_fds=True,
)
- @mock.patch('tempfile.tempdir', '/tmp/')
- @mock.patch('tempfile._RandomNameSequence.__next__')
@mock.patch('subprocess.Popen')
- def test_mysql_to_hive_tblproperties(self, mock_popen, mock_temp_dir):
+ def test_mysql_to_hive_tblproperties(self, mock_popen):
mock_subprocess = MockSubProcess()
mock_popen.return_value = mock_subprocess
- mock_temp_dir.return_value = "test_mysql_to_hive"
with mock.patch.dict('os.environ', self.env_vars):
sql = "SELECT * FROM baby_names LIMIT 1000;"
@@ -283,14 +296,14 @@ class TestTransfer(unittest.TestCase):
'-hiveconf',
'tez.queue.name=airflow',
'-f',
- '/tmp/airflow_hiveop_test_mysql_to_hive/tmptest_mysql_to_hive',
+ HiveopTempFile(),
]
mock_popen.assert_called_with(
hive_cmd,
stdout=mock_subprocess.PIPE,
stderr=mock_subprocess.STDOUT,
- cwd="/tmp/airflow_hiveop_test_mysql_to_hive",
+ cwd=HiveopTempDir(),
close_fds=True,
)
@@ -340,13 +353,10 @@ class TestTransfer(unittest.TestCase):
with hook.get_conn() as conn:
conn.execute(f"DROP TABLE IF EXISTS {mysql_table}")
- @mock.patch('tempfile.tempdir', '/tmp/')
- @mock.patch('tempfile._RandomNameSequence.__next__')
@mock.patch('subprocess.Popen')
- def test_mysql_to_hive_verify_csv_special_char(self, mock_popen, mock_temp_dir):
+ def test_mysql_to_hive_verify_csv_special_char(self, mock_popen):
mock_subprocess = MockSubProcess()
mock_popen.return_value = mock_subprocess
- mock_temp_dir.return_value = "test_mysql_to_hive"
mysql_table = 'test_mysql_to_hive'
hive_table = 'test_mysql_to_hive'
@@ -424,27 +434,24 @@ class TestTransfer(unittest.TestCase):
'-hiveconf',
'tez.queue.name=airflow',
'-f',
- '/tmp/airflow_hiveop_test_mysql_to_hive/tmptest_mysql_to_hive',
+ HiveopTempFile(),
]
mock_popen.assert_called_with(
hive_cmd,
stdout=mock_subprocess.PIPE,
stderr=mock_subprocess.STDOUT,
- cwd="/tmp/airflow_hiveop_test_mysql_to_hive",
+ cwd=HiveopTempDir(),
close_fds=True,
)
finally:
with hook.get_conn() as conn:
conn.execute(f"DROP TABLE IF EXISTS {mysql_table}")
- @mock.patch('tempfile.tempdir', '/tmp/')
- @mock.patch('tempfile._RandomNameSequence.__next__')
@mock.patch('subprocess.Popen')
- def test_mysql_to_hive_verify_loaded_values(self, mock_popen, mock_temp_dir):
+ def test_mysql_to_hive_verify_loaded_values(self, mock_popen):
mock_subprocess = MockSubProcess()
mock_popen.return_value = mock_subprocess
- mock_temp_dir.return_value = "test_mysql_to_hive"
mysql_table = 'test_mysql_to_hive'
hive_table = 'test_mysql_to_hive'
@@ -537,14 +544,14 @@ class TestTransfer(unittest.TestCase):
'-hiveconf',
'tez.queue.name=airflow',
'-f',
- '/tmp/airflow_hiveop_test_mysql_to_hive/tmptest_mysql_to_hive',
+ HiveopTempFile(),
]
mock_popen.assert_called_with(
hive_cmd,
stdout=mock_subprocess.PIPE,
stderr=mock_subprocess.STDOUT,
- cwd="/tmp/airflow_hiveop_test_mysql_to_hive",
+ cwd=HiveopTempDir(),
close_fds=True,
)