You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2021/09/15 13:37:03 UTC

[airflow] branch main updated: Sort adopted tasks in _check_for_stalled_adopted_tasks method (#18208)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9a7243a  Sort adopted tasks in _check_for_stalled_adopted_tasks method (#18208)
9a7243a is described below

commit 9a7243adb8ec4d3d9185bad74da22e861582ffbe
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Wed Sep 15 14:36:45 2021 +0100

    Sort adopted tasks in _check_for_stalled_adopted_tasks method (#18208)
    
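    This PR sorts adopted_task_timeouts by timeout before iterating over them
    in _check_for_stalled_adopted_tasks. The loop stops at the first timeout
    that is still in the future, so the entries must be visited in ascending
    order to ensure we correctly clear all stalled adopted tasks.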
---
 airflow/executors/celery_executor.py    |  4 +++-
 tests/executors/test_celery_executor.py | 28 ++++++++++++++++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py
index 40a94f8..f257b0c 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -347,8 +347,10 @@ class CeleryExecutor(BaseExecutor):
         """
         now = utcnow()
 
+        sorted_adopted_task_timeouts = sorted(self.adopted_task_timeouts.items(), key=lambda k: k[1])
+
         timedout_keys = []
-        for key, stalled_after in self.adopted_task_timeouts.items():
+        for key, stalled_after in sorted_adopted_task_timeouts:
             if stalled_after > now:
                 # Since items are stored sorted, if we get to a stalled_after
                 # in the future then we can stop
diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py
index 636d49d..db63b18 100644
--- a/tests/executors/test_celery_executor.py
+++ b/tests/executors/test_celery_executor.py
@@ -383,6 +383,34 @@ class TestCeleryExecutor(unittest.TestCase):
         assert executor.running == set()
         assert executor.adopted_task_timeouts == {}
 
+    @pytest.mark.backend("mysql", "postgres")
+    def test_check_for_stalled_adopted_tasks_goes_in_ordered_fashion(self):
+        start_date = timezone.utcnow() - timedelta(days=2)
+        queued_dttm = timezone.utcnow() - timedelta(minutes=30)
+        queued_dttm_2 = timezone.utcnow() - timedelta(minutes=4)
+
+        try_number = 1
+
+        with DAG("test_check_for_stalled_adopted_tasks") as dag:
+            task_1 = BaseOperator(task_id="task_1", start_date=start_date)
+            task_2 = BaseOperator(task_id="task_2", start_date=start_date)
+
+        key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, "runid", try_number)
+        key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, "runid", try_number)
+
+        executor = celery_executor.CeleryExecutor()
+        executor.adopted_task_timeouts = {
+            key_2: queued_dttm_2 + executor.task_adoption_timeout,
+            key_1: queued_dttm + executor.task_adoption_timeout,
+        }
+        executor.running = {key_1, key_2}
+        executor.tasks = {key_1: AsyncResult("231"), key_2: AsyncResult("232")}
+        executor.sync()
+        assert executor.event_buffer == {key_1: (State.FAILED, None)}
+        assert executor.tasks == {key_2: AsyncResult('232')}
+        assert executor.running == {key_2}
+        assert executor.adopted_task_timeouts == {key_2: queued_dttm_2 + executor.task_adoption_timeout}
+
 
 def test_operation_timeout_config():
     assert celery_executor.OPERATION_TIMEOUT == 1