You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2022/03/22 19:16:37 UTC
[airflow] 26/31: Fix race condition between triggerer and scheduler (#21316)
This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 8bb8d2571a3a988bf05696e9835b99339a9b5f09
Author: Malthe Borch <mb...@gmail.com>
AuthorDate: Tue Feb 15 13:12:51 2022 +0000
Fix race condition between triggerer and scheduler (#21316)
(cherry picked from commit 2a6792d94d153c6f2dd116843a43ee63cd296c8d)
---
airflow/executors/base_executor.py | 36 ++++++++++++++++++---
tests/executors/test_base_executor.py | 60 ++++++++++++++++++++++++++++++++---
2 files changed, 87 insertions(+), 9 deletions(-)
diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index f7ad45a..1d993bb 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -17,7 +17,7 @@
"""Base executor - this is the base class for all the implemented executors."""
import sys
from collections import OrderedDict
-from typing import Any, Dict, List, Optional, Set, Tuple
+from typing import Any, Counter, Dict, List, Optional, Set, Tuple
from airflow.configuration import conf
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
@@ -29,6 +29,8 @@ PARALLELISM: int = conf.getint('core', 'PARALLELISM')
NOT_STARTED_MESSAGE = "The executor should be started first!"
+QUEUEING_ATTEMPTS = 5
+
# Command to execute - list of strings
# the first element is always "airflow".
# It should be result of TaskInstance.generate_command method.q
@@ -63,6 +65,7 @@ class BaseExecutor(LoggingMixin):
self.queued_tasks: OrderedDict[TaskInstanceKey, QueuedTaskInstanceType] = OrderedDict()
self.running: Set[TaskInstanceKey] = set()
self.event_buffer: Dict[TaskInstanceKey, EventBufferValueType] = {}
+ self.attempts: Counter[TaskInstanceKey] = Counter()
def __repr__(self):
return f"{self.__class__.__name__}(parallelism={self.parallelism})"
@@ -78,7 +81,7 @@ class BaseExecutor(LoggingMixin):
queue: Optional[str] = None,
):
"""Queues command to task"""
- if task_instance.key not in self.queued_tasks and task_instance.key not in self.running:
+ if task_instance.key not in self.queued_tasks:
self.log.info("Adding to queue: %s", command)
self.queued_tasks[task_instance.key] = (command, priority, queue, task_instance)
else:
@@ -183,9 +186,32 @@ class BaseExecutor(LoggingMixin):
for _ in range(min((open_slots, len(self.queued_tasks)))):
key, (command, _, queue, ti) = sorted_queue.pop(0)
- self.queued_tasks.pop(key)
- self.running.add(key)
- self.execute_async(key=key, command=command, queue=queue, executor_config=ti.executor_config)
+
+ # If a task makes it here but is still understood by the executor
+ # to be running, it generally means that the task has been killed
+ # externally and not yet been marked as failed.
+ #
+ # However, when a task is deferred, there is also a possibility of
+ # a race condition where a task might be scheduled again during
+ # trigger processing, even before we are able to register that the
+ # deferred task has completed. In this case and for this reason,
+ # we make a small number of attempts to see if the task has been
+ # removed from the running set in the meantime.
+ if key in self.running:
+ attempt = self.attempts[key]
+ if attempt < QUEUEING_ATTEMPTS - 1:
+ self.attempts[key] = attempt + 1
+ self.log.info("task %s is still running", key)
+ continue
+
+ # We give up and remove the task from the queue.
+ self.log.error("could not queue task %s (still running after %d attempts)", key, attempt)
+ del self.attempts[key]
+ del self.queued_tasks[key]
+ else:
+ del self.queued_tasks[key]
+ self.running.add(key)
+ self.execute_async(key=key, command=command, queue=queue, executor_config=ti.executor_config)
def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
"""
diff --git a/tests/executors/test_base_executor.py b/tests/executors/test_base_executor.py
index 49d6c01..40bf8eb 100644
--- a/tests/executors/test_base_executor.py
+++ b/tests/executors/test_base_executor.py
@@ -18,7 +18,9 @@
from datetime import timedelta
from unittest import mock
-from airflow.executors.base_executor import BaseExecutor
+from pytest import mark
+
+from airflow.executors.base_executor import QUEUEING_ATTEMPTS, BaseExecutor
from airflow.models.baseoperator import BaseOperator
from airflow.models.taskinstance import TaskInstanceKey
from airflow.utils import timezone
@@ -57,7 +59,7 @@ def test_gauge_executor_metrics(mock_stats_gauge, mock_trigger_tasks, mock_sync)
mock_stats_gauge.assert_has_calls(calls)
-def test_try_adopt_task_instances(dag_maker):
+def setup_dagrun(dag_maker):
date = timezone.utcnow()
start_date = date - timedelta(days=2)
@@ -66,8 +68,58 @@ def test_try_adopt_task_instances(dag_maker):
BaseOperator(task_id="task_2", start_date=start_date)
BaseOperator(task_id="task_3", start_date=start_date)
- dagrun = dag_maker.create_dagrun(execution_date=date)
- tis = dagrun.task_instances
+ return dag_maker.create_dagrun(execution_date=date)
+
+def test_try_adopt_task_instances(dag_maker):
+ dagrun = setup_dagrun(dag_maker)
+ tis = dagrun.task_instances
assert {ti.task_id for ti in tis} == {"task_1", "task_2", "task_3"}
assert BaseExecutor().try_adopt_task_instances(tis) == tis
+
+
+def enqueue_tasks(executor, dagrun):
+ for task_instance in dagrun.task_instances:
+ executor.queue_command(task_instance, ["airflow"])
+
+
+def setup_trigger_tasks(dag_maker):
+ dagrun = setup_dagrun(dag_maker)
+ executor = BaseExecutor()
+ executor.execute_async = mock.Mock()
+ enqueue_tasks(executor, dagrun)
+ return executor, dagrun
+
+
+@mark.parametrize("open_slots", [1, 2, 3])
+def test_trigger_queued_tasks(dag_maker, open_slots):
+ executor, _ = setup_trigger_tasks(dag_maker)
+ executor.trigger_tasks(open_slots)
+ assert len(executor.execute_async.mock_calls) == open_slots
+
+
+@mark.parametrize("change_state_attempt", range(QUEUEING_ATTEMPTS + 2))
+def test_trigger_running_tasks(dag_maker, change_state_attempt):
+ executor, dagrun = setup_trigger_tasks(dag_maker)
+ open_slots = 100
+ executor.trigger_tasks(open_slots)
+ expected_calls = len(dagrun.task_instances) # initially `execute_async` called for each task
+ assert len(executor.execute_async.mock_calls) == expected_calls
+
+ # All the tasks are now "running", so while we enqueue them again here,
+ # they won't be executed again until the executor has been notified of a state change.
+ enqueue_tasks(executor, dagrun)
+
+ for attempt in range(QUEUEING_ATTEMPTS + 2):
+ # On the configured attempt, we notify the executor that the task has succeeded.
+ if attempt == change_state_attempt:
+ executor.change_state(dagrun.task_instances[0].key, State.SUCCESS)
+ # If we have not exceeded QUEUEING_ATTEMPTS, we should expect an additional "execute" call
+ if attempt < QUEUEING_ATTEMPTS:
+ expected_calls += 1
+ executor.trigger_tasks(open_slots)
+ assert len(executor.execute_async.mock_calls) == expected_calls
+ if change_state_attempt < QUEUEING_ATTEMPTS:
+ assert len(executor.execute_async.mock_calls) == len(dagrun.task_instances) + 1
+ else:
+ assert len(executor.execute_async.mock_calls) == len(dagrun.task_instances)