You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2022/03/22 14:37:32 UTC

[airflow] 26/28: Fix race condition between triggerer and scheduler (#21316)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 21bcb055b6214687b740ae12d9ba4e645af67686
Author: Malthe Borch <mb...@gmail.com>
AuthorDate: Tue Feb 15 13:12:51 2022 +0000

    Fix race condition between triggerer and scheduler (#21316)
    
    (cherry picked from commit 2a6792d94d153c6f2dd116843a43ee63cd296c8d)
---
 airflow/executors/base_executor.py    | 36 ++++++++++++++++++---
 tests/executors/test_base_executor.py | 60 ++++++++++++++++++++++++++++++++---
 2 files changed, 87 insertions(+), 9 deletions(-)

diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index f7ad45a..1d993bb 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -17,7 +17,7 @@
 """Base executor - this is the base class for all the implemented executors."""
 import sys
 from collections import OrderedDict
-from typing import Any, Dict, List, Optional, Set, Tuple
+from typing import Any, Counter, Dict, List, Optional, Set, Tuple
 
 from airflow.configuration import conf
 from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
@@ -29,6 +29,8 @@ PARALLELISM: int = conf.getint('core', 'PARALLELISM')
 
 NOT_STARTED_MESSAGE = "The executor should be started first!"
 
+QUEUEING_ATTEMPTS = 5
+
 # Command to execute - list of strings
 # the first element is always "airflow".
 # It should be result of TaskInstance.generate_command method.q
@@ -63,6 +65,7 @@ class BaseExecutor(LoggingMixin):
         self.queued_tasks: OrderedDict[TaskInstanceKey, QueuedTaskInstanceType] = OrderedDict()
         self.running: Set[TaskInstanceKey] = set()
         self.event_buffer: Dict[TaskInstanceKey, EventBufferValueType] = {}
+        self.attempts: Counter[TaskInstanceKey] = Counter()
 
     def __repr__(self):
         return f"{self.__class__.__name__}(parallelism={self.parallelism})"
@@ -78,7 +81,7 @@ class BaseExecutor(LoggingMixin):
         queue: Optional[str] = None,
     ):
         """Queues command to task"""
-        if task_instance.key not in self.queued_tasks and task_instance.key not in self.running:
+        if task_instance.key not in self.queued_tasks:
             self.log.info("Adding to queue: %s", command)
             self.queued_tasks[task_instance.key] = (command, priority, queue, task_instance)
         else:
@@ -183,9 +186,37 @@ class BaseExecutor(LoggingMixin):
 
         for _ in range(min((open_slots, len(self.queued_tasks)))):
             key, (command, _, queue, ti) = sorted_queue.pop(0)
-            self.queued_tasks.pop(key)
-            self.running.add(key)
-            self.execute_async(key=key, command=command, queue=queue, executor_config=ti.executor_config)
+
+            # If a task makes it here but is still understood by the executor
+            # to be running, it generally means that the task has been killed
+            # externally and not yet been marked as failed.
+            #
+            # However, when a task is deferred, there is also a possibility of
+            # a race condition where a task might be scheduled again during
+            # trigger processing, even before we are able to register that the
+            # deferred task has completed. In this case and for this reason,
+            # we make a small number of attempts to see if the task has been
+            # removed from the running set in the meantime.
+            if key in self.running:
+                attempt = self.attempts[key]
+                if attempt < QUEUEING_ATTEMPTS - 1:
+                    self.attempts[key] = attempt + 1
+                    self.log.info("task %s is still running", key)
+                    continue
+
+                # We give up and remove the task from the queue. `attempt` is
+                # zero-based, so this pass is attempt number `attempt + 1`.
+                self.log.error("could not queue task %s (still running after %d attempts)", key, attempt + 1)
+                del self.attempts[key]
+                del self.queued_tasks[key]
+            else:
+                # The task is free to run: discard any attempts recorded while
+                # it was still running, so the counter does not leak entries and
+                # a later re-queue of the same key starts with a full budget.
+                self.attempts.pop(key, None)
+                del self.queued_tasks[key]
+                self.running.add(key)
+                self.execute_async(key=key, command=command, queue=queue, executor_config=ti.executor_config)
 
     def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
         """
diff --git a/tests/executors/test_base_executor.py b/tests/executors/test_base_executor.py
index 49d6c01..40bf8eb 100644
--- a/tests/executors/test_base_executor.py
+++ b/tests/executors/test_base_executor.py
@@ -18,7 +18,9 @@
 from datetime import timedelta
 from unittest import mock
 
-from airflow.executors.base_executor import BaseExecutor
+from pytest import mark
+
+from airflow.executors.base_executor import QUEUEING_ATTEMPTS, BaseExecutor
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.taskinstance import TaskInstanceKey
 from airflow.utils import timezone
@@ -57,7 +59,7 @@ def test_gauge_executor_metrics(mock_stats_gauge, mock_trigger_tasks, mock_sync)
     mock_stats_gauge.assert_has_calls(calls)
 
 
-def test_try_adopt_task_instances(dag_maker):
+def setup_dagrun(dag_maker):
     date = timezone.utcnow()
     start_date = date - timedelta(days=2)
 
@@ -66,8 +68,58 @@ def test_try_adopt_task_instances(dag_maker):
         BaseOperator(task_id="task_2", start_date=start_date)
         BaseOperator(task_id="task_3", start_date=start_date)
 
-    dagrun = dag_maker.create_dagrun(execution_date=date)
-    tis = dagrun.task_instances
+    return dag_maker.create_dagrun(execution_date=date)
 
+
+def test_try_adopt_task_instances(dag_maker):
+    dagrun = setup_dagrun(dag_maker)
+    tis = dagrun.task_instances
     assert {ti.task_id for ti in tis} == {"task_1", "task_2", "task_3"}
     assert BaseExecutor().try_adopt_task_instances(tis) == tis
+
+
+def enqueue_tasks(executor, dagrun):
+    for task_instance in dagrun.task_instances:
+        executor.queue_command(task_instance, ["airflow"])
+
+
+def setup_trigger_tasks(dag_maker):
+    dagrun = setup_dagrun(dag_maker)
+    executor = BaseExecutor()
+    executor.execute_async = mock.Mock()
+    enqueue_tasks(executor, dagrun)
+    return executor, dagrun
+
+
+@mark.parametrize("open_slots", [1, 2, 3])
+def test_trigger_queued_tasks(dag_maker, open_slots):
+    executor, _ = setup_trigger_tasks(dag_maker)
+    executor.trigger_tasks(open_slots)
+    assert len(executor.execute_async.mock_calls) == open_slots
+
+
+@mark.parametrize("change_state_attempt", range(QUEUEING_ATTEMPTS + 2))
+def test_trigger_running_tasks(dag_maker, change_state_attempt):
+    executor, dagrun = setup_trigger_tasks(dag_maker)
+    open_slots = 100
+    executor.trigger_tasks(open_slots)
+    expected_calls = len(dagrun.task_instances)  # initially `execute_async` called for each task
+    assert len(executor.execute_async.mock_calls) == expected_calls
+
+    # All the tasks are now "running", so while we enqueue them again here,
+    # they won't be executed again until the executor has been notified of a state change.
+    enqueue_tasks(executor, dagrun)
+
+    for attempt in range(QUEUEING_ATTEMPTS + 2):
+        # On the configured attempt, we notify the executor that the task has succeeded.
+        if attempt == change_state_attempt:
+            executor.change_state(dagrun.task_instances[0].key, State.SUCCESS)
+            # If we have not exceeded QUEUEING_ATTEMPTS, we should expect an additional "execute" call
+            if attempt < QUEUEING_ATTEMPTS:
+                expected_calls += 1
+        executor.trigger_tasks(open_slots)
+        assert len(executor.execute_async.mock_calls) == expected_calls
+    if change_state_attempt < QUEUEING_ATTEMPTS:
+        assert len(executor.execute_async.mock_calls) == len(dagrun.task_instances) + 1
+    else:
+        assert len(executor.execute_async.mock_calls) == len(dagrun.task_instances)