You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2022/03/22 14:37:29 UTC

[airflow] 23/28: Fix duplicate trigger creation race condition (#20699)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 6cd987078090cbadec28aa325800dc6447c4132d
Author: Daniel Standish <15...@users.noreply.github.com>
AuthorDate: Thu Jan 6 15:16:02 2022 -0800

    Fix duplicate trigger creation race condition (#20699)
    
    The process for queueing up a trigger for execution by the TriggerRunner is handled by the TriggererJob's `load_triggers` method.  It fetches the triggers that should be running according to the database, checks whether they are running and, if not, adds them to `TriggerRunner.to_create`.  The problem is that there's a small window of time between the moment a trigger (upon termination) is purged from the `TriggerRunner.triggers` set and the time that the database is updated to reflect t [...]
    
    To resolve this, before adding a trigger to the `to_create` queue, instead of comparing against only the "running" triggers, we compare against all triggers known to the TriggerRunner instance.  When triggers move out of the `triggers` set they move into other data structures such as `events`, `failed_triggers`, and `to_cancel`.  So we union all of these and only create those triggers which the database indicates should exist _and_ which are not already being handled  [...]
    
    (cherry picked from commit 16b8c476518ed76e3689966ec4b0b788be935410)
---
 airflow/jobs/triggerer_job.py    |  12 +++-
 tests/jobs/test_triggerer_job.py | 136 ++++++++++++++++++++++++++++++++++++++-
 2 files changed, 142 insertions(+), 6 deletions(-)

diff --git a/airflow/jobs/triggerer_job.py b/airflow/jobs/triggerer_job.py
index 25a4c79..dff0e0f 100644
--- a/airflow/jobs/triggerer_job.py
+++ b/airflow/jobs/triggerer_job.py
@@ -381,10 +381,16 @@ class TriggerRunner(threading.Thread, LoggingMixin):
         # line's execution, but we consider that safe, since there's a strict
         # add -> remove -> never again lifecycle this function is already
         # handling.
-        current_trigger_ids = set(self.triggers.keys())
+        running_trigger_ids = set(self.triggers.keys())
+        known_trigger_ids = (
+            running_trigger_ids.union(x[0] for x in self.events)
+            .union(self.to_cancel)
+            .union(x[0] for x in self.to_create)
+            .union(self.failed_triggers)
+        )
         # Work out the two difference sets
-        new_trigger_ids = requested_trigger_ids.difference(current_trigger_ids)
-        cancel_trigger_ids = current_trigger_ids.difference(requested_trigger_ids)
+        new_trigger_ids = requested_trigger_ids - known_trigger_ids
+        cancel_trigger_ids = running_trigger_ids - requested_trigger_ids
         # Bulk-fetch new trigger records
         new_triggers = Trigger.bulk_fetch(new_trigger_ids)
         # Add in new triggers
diff --git a/tests/jobs/test_triggerer_job.py b/tests/jobs/test_triggerer_job.py
index 5adc91f..870116a 100644
--- a/tests/jobs/test_triggerer_job.py
+++ b/tests/jobs/test_triggerer_job.py
@@ -16,29 +16,53 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import asyncio
 import datetime
 import sys
 import time
+from threading import Thread
 
 import pytest
 
-from airflow.jobs.triggerer_job import TriggererJob
-from airflow.models import Trigger
+from airflow.jobs.triggerer_job import TriggererJob, TriggerRunner
+from airflow.models import DagModel, DagRun, TaskInstance, Trigger
 from airflow.operators.dummy import DummyOperator
+from airflow.operators.python import PythonOperator
 from airflow.triggers.base import TriggerEvent
 from airflow.triggers.temporal import TimeDeltaTrigger
 from airflow.triggers.testing import FailureTrigger, SuccessTrigger
 from airflow.utils import timezone
 from airflow.utils.session import create_session
 from airflow.utils.state import State, TaskInstanceState
-from tests.test_utils.db import clear_db_runs
+from tests.test_utils.db import clear_db_dags, clear_db_runs
+
+
+class TimeDeltaTrigger_(TimeDeltaTrigger):  # TimeDeltaTrigger that logs every run() start to a file
+    def __init__(self, delta, filename):
+        super().__init__(delta=delta)
+        self.filename = filename  # file that run() appends a marker line to
+        self.delta = delta  # kept so serialize() can return the original delta
+
+    async def run(self):
+        with open(self.filename, 'at') as f:  # one 'hi' line per run() call lets tests count starts
+            f.write('hi\n')
+        async for event in super().run():  # then defer to TimeDeltaTrigger's own events
+            yield event
+
+    def serialize(self):  # (classpath, kwargs); kwargs match __init__'s parameters
+        return (
+            "tests.jobs.test_triggerer_job.TimeDeltaTrigger_",
+            {"delta": self.delta, "filename": self.filename},
+        )
 
 
 @pytest.fixture(autouse=True)
 def clean_database():
     """Fixture that cleans the database before and after every test."""
     clear_db_runs()
+    clear_db_dags()  # the race-condition test in this module inserts a DagModel row
     yield  # Test runs here
+    clear_db_dags()  # mirror the setup-time DAG cleanup
     clear_db_runs()
 
 
@@ -160,6 +184,112 @@ def test_trigger_lifecycle(session):
 
 
 @pytest.mark.skipif(sys.version_info.minor <= 6 and sys.version_info.major <= 3, reason="No triggerer on 3.6")
+def test_trigger_create_race_condition_18392(session, tmp_path):
+    """
+    Verify the fix for the race condition documented in GitHub issue #18392.
+    Triggers are queued for creation by TriggererJob.load_triggers.
+    A race condition could cause the same trigger to be created more than once:
+    the runner completes the trigger and purges it from its "running" set, then
+    job.load_triggers runs and, because the trigger is no longer running but is still
+    in the DB, it is queued again.
+
+    The scenario is as follows:
+        1. job.load_triggers (trigger now queued)
+        2. runner.create_triggers (trigger now running)
+        3. job.handle_events (trigger still appears running so state not updated in DB)
+        4. runner.cleanup_finished_triggers (trigger completed; purged from "running" set)
+        5. job.load_triggers (trigger not running, but also not purged from DB, so it is queued again)
+        6. runner.create_triggers (trigger created again)
+
+    This test verifies that under this scenario only one trigger instance is started.
+    """
+    path = tmp_path / 'test_trigger_bad_respawn.txt'  # TimeDeltaTrigger_ appends one line per run()
+
+    class TriggerRunner_(TriggerRunner):
+        """Runner that pauses at key points until the job thread has caught up."""
+
+        async def wait_for_job_method_count(self, method, count):  # polls for up to ~3s
+            for _ in range(30):
+                await asyncio.sleep(0.1)
+                if getattr(self, f'{method}_count', 0) >= count:
+                    break
+            else:
+                pytest.fail(f"did not observe count {count} in job method {method}")
+
+        async def create_triggers(self):
+            """
+            On the first loop, wait for job.load_triggers so the trigger is already queued.
+            """
+            if getattr(self, 'loop_count', 0) == 0:
+                await self.wait_for_job_method_count('load_triggers', 1)
+            await super().create_triggers()
+            self.loop_count = getattr(self, 'loop_count', 0) + 1  # completed create_triggers calls
+
+        async def cleanup_finished_triggers(self):
+            """On loop 1, make sure that job.handle_events was already called."""
+            if self.loop_count == 1:
+                await self.wait_for_job_method_count('handle_events', 1)
+            await super().cleanup_finished_triggers()
+
+    class TriggererJob_(TriggererJob):
+        """Job that counts its method calls on the runner and waits for runner loops."""
+
+        def wait_for_runner_loop(self, runner_loop_count):
+            for _ in range(30):
+                time.sleep(0.1)
+                if getattr(self.runner, 'loop_count', 0) >= runner_loop_count:  # set by TriggerRunner_
+                    break
+            else:
+                pytest.fail(f"did not observe {runner_loop_count} loops in the runner thread")
+
+        def load_triggers(self):
+            """On second run, make sure that runner has called create_triggers in its second loop"""
+            super().load_triggers()
+            self.runner.load_triggers_count = getattr(self.runner, 'load_triggers_count', 0) + 1
+            if self.runner.load_triggers_count == 2:
+                self.wait_for_runner_loop(runner_loop_count=2)
+
+        def handle_events(self):
+            super().handle_events()
+            self.runner.handle_events_count = getattr(self.runner, 'handle_events_count', 0) + 1
+
+    trigger = TimeDeltaTrigger_(delta=datetime.timedelta(microseconds=1), filename=path.as_posix())
+    trigger_orm = Trigger.from_object(trigger)
+    trigger_orm.id = 1  # fixed id so the TaskInstance below can reference it
+    session.add(trigger_orm)
+
+    dag = DagModel(dag_id='test-dag')
+    dag_run = DagRun(dag.dag_id, run_id='abc', run_type='none')
+    ti = TaskInstance(PythonOperator(task_id='dummy-task', python_callable=print), run_id=dag_run.run_id)
+    ti.dag_id = dag.dag_id
+    ti.trigger_id = 1  # link the TI to the trigger row above
+    session.add(dag)
+    session.add(dag_run)
+    session.add(ti)
+
+    session.commit()
+
+    job = TriggererJob_()
+    job.runner = TriggerRunner_()
+    thread = Thread(target=job._execute)
+    thread.start()
+    try:
+        for _ in range(40):
+            time.sleep(0.1)
+            # ready to evaluate after 2 loops
+            if getattr(job.runner, 'loop_count', 0) >= 2:
+                break
+        else:
+            pytest.fail("did not observe 2 loops in the runner thread")
+    finally:
+        job.runner.stop = True
+        job.runner.join()
+        thread.join()
+    instances = path.read_text().splitlines()
+    assert len(instances) == 1
+
+
+@pytest.mark.skipif(sys.version_info.minor <= 6 and sys.version_info.major <= 3, reason="No triggerer on 3.6")
 def test_trigger_from_dead_triggerer(session):
     """
     Checks that the triggerer will correctly claim a Trigger that is assigned to a