You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2021/03/19 15:06:07 UTC

[airflow] 05/42: Speed up clear_task_instances by doing a single sql delete for TaskReschedule (#14048)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 118f86c394f97fa628fc2069bb31ba29d70e37d8
Author: yuqian90 <yu...@gmail.com>
AuthorDate: Wed Feb 10 22:07:18 2021 +0800

    Speed up clear_task_instances by doing a single sql delete for TaskReschedule (#14048)
    
    Clearing a large number of tasks takes a long time. More than 95% of the time is spent in the following statement in clear_task_instances. This slowness sometimes causes the webserver to time out because web_server_worker_timeout is exceeded.
    
    ```
            # Clear all reschedules related to the ti to clear
            session.query(TR).filter(
                TR.dag_id == ti.dag_id,
                TR.task_id == ti.task_id,
                TR.execution_date == ti.execution_date,
                TR.try_number == ti.try_number,
            ).delete()
    ```
    This statement was very slow because it runs inside a for loop, deleting the TaskReschedule rows of each task instance with a separate query.
    
    This PR changes the code to delete the TaskReschedule rows in a single SQL query built from nested OR conditions. The effect is the same, but it is much faster.
    
    Profiling showed a large speed-up (roughly 40 to 50 times faster) compared with the first iteration of this PR, so overall the deletion should now be about 300 times faster than the original for-loop deletion.
    
    (cherry picked from commit 9036ce20c140520d3f9d5e0f83b5ebfded07fa7c)
---
 airflow/models/taskinstance.py  | 37 ++++++++++++++++++++++++++------
 tests/models/test_cleartasks.py | 47 ++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 77 insertions(+), 7 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index d671a01..c7d7ff7 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -24,6 +24,7 @@ import os
 import pickle
 import signal
 import warnings
+from collections import defaultdict
 from datetime import datetime, timedelta
 from tempfile import NamedTemporaryFile
 from typing import IO, Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
@@ -146,6 +147,7 @@ def clear_task_instances(
     :param dag: DAG object
     """
     job_ids = []
+    task_id_by_key = defaultdict(lambda: defaultdict(lambda: defaultdict(set)))
     for ti in tis:
         if ti.state == State.RUNNING:
             if ti.job_id:
@@ -166,13 +168,36 @@ def clear_task_instances(
                 ti.max_tries = max(ti.max_tries, ti.prev_attempted_tries)
             ti.state = State.NONE
             session.merge(ti)
+
+        task_id_by_key[ti.dag_id][ti.execution_date][ti.try_number].add(ti.task_id)
+
+    if task_id_by_key:
         # Clear all reschedules related to the ti to clear
-        session.query(TR).filter(
-            TR.dag_id == ti.dag_id,
-            TR.task_id == ti.task_id,
-            TR.execution_date == ti.execution_date,
-            TR.try_number == ti.try_number,
-        ).delete()
+
+        # This is an optimization for the common case where all tis are for a small number
+        # of dag_id, execution_date and try_number. Use a nested dict of dag_id,
+        # execution_date, try_number and task_id to construct the where clause in a
+        # hierarchical manner. This speeds up the delete statement by more than 40x for
+        # large number of tis (50k+).
+        conditions = or_(
+            and_(
+                TR.dag_id == dag_id,
+                or_(
+                    and_(
+                        TR.execution_date == execution_date,
+                        or_(
+                            and_(TR.try_number == try_number, TR.task_id.in_(task_ids))
+                            for try_number, task_ids in task_tries.items()
+                        ),
+                    )
+                    for execution_date, task_tries in dates.items()
+                ),
+            )
+            for dag_id, dates in task_id_by_key.items()
+        )
+
+        delete_qry = TR.__table__.delete().where(conditions)
+        session.execute(delete_qry)
 
     if job_ids:
         from airflow.jobs.base_job import BaseJob
diff --git a/tests/models/test_cleartasks.py b/tests/models/test_cleartasks.py
index f54bacc..1c5606e 100644
--- a/tests/models/test_cleartasks.py
+++ b/tests/models/test_cleartasks.py
@@ -20,8 +20,9 @@ import datetime
 import unittest
 
 from airflow import settings
-from airflow.models import DAG, TaskInstance as TI, clear_task_instances
+from airflow.models import DAG, TaskInstance as TI, TaskReschedule, clear_task_instances
 from airflow.operators.dummy import DummyOperator
+from airflow.sensors.python import PythonSensor
 from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.types import DagRunType
@@ -138,6 +139,50 @@ class TestClearTasks(unittest.TestCase):
         assert ti1.try_number == 2
         assert ti1.max_tries == 2
 
+    def test_clear_task_instances_with_task_reschedule(self):
+        """Clearing a TI deletes its own TaskReschedule rows and leaves other TIs' rows untouched."""
+
+        with DAG(
+            'test_clear_task_instances_with_task_reschedule',
+            start_date=DEFAULT_DATE,
+            end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+        ) as dag:
+            task0 = PythonSensor(task_id='0', python_callable=lambda: False, mode="reschedule")
+            task1 = PythonSensor(task_id='1', python_callable=lambda: False, mode="reschedule")
+        # Both sensors' callables return False, so each run() should end in a reschedule.
+        ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
+        ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
+
+        dag.create_dagrun(
+            execution_date=ti0.execution_date,
+            state=State.RUNNING,
+            run_type=DagRunType.SCHEDULED,
+        )
+
+        ti0.run()
+        ti1.run()
+
+        with create_session() as session:
+            # Number of TaskReschedule rows for task_id in this DAG at DEFAULT_DATE, try 1.
+            def count_task_reschedule(task_id):
+                return (
+                    session.query(TaskReschedule)
+                    .filter(
+                        TaskReschedule.dag_id == dag.dag_id,
+                        TaskReschedule.task_id == task_id,
+                        TaskReschedule.execution_date == DEFAULT_DATE,
+                        TaskReschedule.try_number == 1,
+                    )
+                    .count()
+                )
+            # Each sensor starts with one reschedule; clearing only ti0 must leave ti1's row.
+            assert count_task_reschedule(ti0.task_id) == 1
+            assert count_task_reschedule(ti1.task_id) == 1
+            qry = session.query(TI).filter(TI.dag_id == dag.dag_id, TI.task_id == ti0.task_id).all()
+            clear_task_instances(qry, session, dag=dag)
+            assert count_task_reschedule(ti0.task_id) == 0
+            assert count_task_reschedule(ti1.task_id) == 1
+
     def test_dag_clear(self):
         dag = DAG(
             'test_dag_clear', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10)