You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2022/03/22 14:37:29 UTC
[airflow] 23/28: Fix duplicate trigger creation race condition (#20699)
This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 6cd987078090cbadec28aa325800dc6447c4132d
Author: Daniel Standish <15...@users.noreply.github.com>
AuthorDate: Thu Jan 6 15:16:02 2022 -0800
Fix duplicate trigger creation race condition (#20699)
The process for queueing up a trigger, for execution by the TriggerRunner, is handled by the TriggerJob's `load_triggers` method. It fetches the triggers that should be running according to the database, checks whether they are running, and if not, adds them to `TriggerRunner.to_create`. The problem is that there's a small window of time between the moment a trigger (upon termination) is purged from the `TriggerRunner.triggers` set, and the time that the database is updated to reflect t [...]
To resolve this, before adding a trigger to the `to_create` queue, we no longer compare only against the "running" triggers; instead we compare against all triggers known to the TriggerRunner instance. When triggers move out of the `triggers` set, they move into other data structures such as `events`, `failed_triggers` and `to_cancel`. So we take the union of all of these and only create those triggers which the database indicates should exist _and_ which are not already being handled [...]
(cherry picked from commit 16b8c476518ed76e3689966ec4b0b788be935410)
---
airflow/jobs/triggerer_job.py | 12 +++-
tests/jobs/test_triggerer_job.py | 136 ++++++++++++++++++++++++++++++++++++++-
2 files changed, 142 insertions(+), 6 deletions(-)
diff --git a/airflow/jobs/triggerer_job.py b/airflow/jobs/triggerer_job.py
index 25a4c79..dff0e0f 100644
--- a/airflow/jobs/triggerer_job.py
+++ b/airflow/jobs/triggerer_job.py
@@ -381,10 +381,16 @@ class TriggerRunner(threading.Thread, LoggingMixin):
# line's execution, but we consider that safe, since there's a strict
# add -> remove -> never again lifecycle this function is already
# handling.
- current_trigger_ids = set(self.triggers.keys())
+ running_trigger_ids = set(self.triggers.keys())
+ known_trigger_ids = (
+ running_trigger_ids.union(x[0] for x in self.events)
+ .union(self.to_cancel)
+ .union(x[0] for x in self.to_create)
+ .union(self.failed_triggers)
+ )
# Work out the two difference sets
- new_trigger_ids = requested_trigger_ids.difference(current_trigger_ids)
- cancel_trigger_ids = current_trigger_ids.difference(requested_trigger_ids)
+ new_trigger_ids = requested_trigger_ids - known_trigger_ids
+ cancel_trigger_ids = running_trigger_ids - requested_trigger_ids
# Bulk-fetch new trigger records
new_triggers = Trigger.bulk_fetch(new_trigger_ids)
# Add in new triggers
diff --git a/tests/jobs/test_triggerer_job.py b/tests/jobs/test_triggerer_job.py
index 5adc91f..870116a 100644
--- a/tests/jobs/test_triggerer_job.py
+++ b/tests/jobs/test_triggerer_job.py
@@ -16,29 +16,53 @@
# specific language governing permissions and limitations
# under the License.
+import asyncio
import datetime
import sys
import time
+from threading import Thread
import pytest
-from airflow.jobs.triggerer_job import TriggererJob
-from airflow.models import Trigger
+from airflow.jobs.triggerer_job import TriggererJob, TriggerRunner
+from airflow.models import DagModel, DagRun, TaskInstance, Trigger
from airflow.operators.dummy import DummyOperator
+from airflow.operators.python import PythonOperator
from airflow.triggers.base import TriggerEvent
from airflow.triggers.temporal import TimeDeltaTrigger
from airflow.triggers.testing import FailureTrigger, SuccessTrigger
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import State, TaskInstanceState
-from tests.test_utils.db import clear_db_runs
+from tests.test_utils.db import clear_db_dags, clear_db_runs
+
+
+class TimeDeltaTrigger_(TimeDeltaTrigger):
+    """
+    TimeDeltaTrigger that appends one line to ``filename`` every time ``run`` starts.
+
+    Counting the lines of that file afterwards tells the test how many times the
+    trigger was actually started by the TriggerRunner.
+    """
+
+    def __init__(self, delta, filename):
+        super().__init__(delta=delta)
+        self.filename = filename
+        # Kept on the instance because serialize() below reads self.delta.
+        self.delta = delta
+
+    async def run(self):
+        # Record this start before delegating to the normal TimeDeltaTrigger behaviour.
+        # NOTE(review): synchronous file I/O inside a coroutine; acceptable for a tiny test write.
+        with open(self.filename, 'at') as f:
+            f.write('hi\n')
+        async for event in super().run():
+            yield event
+
+    def serialize(self):
+        # Classpath is hard-coded to this test module; presumably the runner re-imports
+        # the trigger from it when loading the DB row — update it if the module moves.
+        return (
+            "tests.jobs.test_triggerer_job.TimeDeltaTrigger_",
+            {"delta": self.delta, "filename": self.filename},
+        )
@pytest.fixture(autouse=True)
def clean_database():
    """Fixture that cleans the database before and after every test."""
    clear_db_runs()
+    # DAG rows are cleared too, because tests in this module insert DagModel rows
+    # (e.g. test_trigger_create_race_condition_18392 adds a DagModel).
+    clear_db_dags()
    yield  # Test runs here
+    clear_db_dags()
    clear_db_runs()
@@ -160,6 +184,112 @@ def test_trigger_lifecycle(session):
@pytest.mark.skipif(sys.version_info.minor <= 6 and sys.version_info.major <= 3, reason="No triggerer on 3.6")
+def test_trigger_create_race_condition_18392(session, tmp_path):
+    """
+    This verifies the resolution of the race condition documented in github issue #18392.
+    Triggers are queued for creation by TriggerJob.load_triggers.
+    There was a race condition where multiple triggers would be created unnecessarily.
+    What happens is the runner completes the trigger and purges it from the "running" list.
+    Then job.load_triggers is called and it looks like the trigger is not running but should be,
+    so it queues it again.
+
+    The scenario is as follows:
+        1. job.load_triggers (trigger now queued)
+        2. runner.create_triggers (trigger now running)
+        3. job.handle_events (trigger still appears running so state not updated in DB)
+        4. runner.cleanup_finished_triggers (trigger completed at this point; trigger removed from "running" set)
+        5. job.load_triggers (trigger not running, but also not purged from DB, so it is queued again)
+        6. runner.create_triggers (trigger created again)
+
+    This test verifies that under this scenario only one trigger is created.
+    """
+    path = tmp_path / 'test_trigger_bad_respawn.txt'
+
+    class TriggerRunner_(TriggerRunner):
+        """Runner that pauses at specific points until the job thread has made progress."""
+
+        async def wait_for_job_method_count(self, method, count):
+            # The job thread records its progress as `<method>_count` attributes on this runner.
+            for _ in range(30):
+                await asyncio.sleep(0.1)
+                if getattr(self, f'{method}_count', 0) >= count:
+                    break
+            else:
+                pytest.fail(f"did not observe count {count} in job method {method}")
+
+        async def create_triggers(self):
+            """
+            On first run, wait for job.load_triggers to make sure they are queued
+            """
+            if getattr(self, 'loop_count', 0) == 0:
+                await self.wait_for_job_method_count('load_triggers', 1)
+            await super().create_triggers()
+            # loop_count is the runner-progress counter polled by the job and main threads.
+            self.loop_count = getattr(self, 'loop_count', 0) + 1
+
+        async def cleanup_finished_triggers(self):
+            """On loop 1, make sure that job.handle_events was already called"""
+            if self.loop_count == 1:
+                await self.wait_for_job_method_count('handle_events', 1)
+            await super().cleanup_finished_triggers()
+
+    class TriggererJob_(TriggererJob):
+        """We do some waiting for runner thread looping (and track calls in job thread)"""
+
+        def wait_for_runner_loop(self, runner_loop_count):
+            for _ in range(30):
+                time.sleep(0.1)
+                # Poll the counter TriggerRunner_.create_triggers actually maintains
+                # (previously 'call_count', which is never set, so this always timed out).
+                if getattr(self.runner, 'loop_count', 0) >= runner_loop_count:
+                    break
+            else:
+                pytest.fail(f"did not observe {runner_loop_count} loops in the runner thread")
+
+        def load_triggers(self):
+            """On second run, make sure that runner has called create_triggers in its second loop"""
+            super().load_triggers()
+            self.runner.load_triggers_count = getattr(self.runner, 'load_triggers_count', 0) + 1
+            if self.runner.load_triggers_count == 2:
+                # Block the job thread until the runner has finished a second create_triggers
+                # pass, so anything this call queued has been created before we move on.
+                self.wait_for_runner_loop(runner_loop_count=2)
+
+        def handle_events(self):
+            super().handle_events()
+            self.runner.handle_events_count = getattr(self.runner, 'handle_events_count', 0) + 1
+
+    # Each start of this trigger appends a line to `path`; the final assertion counts them.
+    trigger = TimeDeltaTrigger_(delta=datetime.timedelta(microseconds=1), filename=path.as_posix())
+    trigger_orm = Trigger.from_object(trigger)
+    trigger_orm.id = 1
+    session.add(trigger_orm)
+
+    dag = DagModel(dag_id='test-dag')
+    dag_run = DagRun(dag.dag_id, run_id='abc', run_type='none')
+    ti = TaskInstance(PythonOperator(task_id='dummy-task', python_callable=print), run_id=dag_run.run_id)
+    ti.dag_id = dag.dag_id
+    ti.trigger_id = 1
+    session.add(dag)
+    session.add(dag_run)
+    session.add(ti)
+
+    session.commit()
+
+    job = TriggererJob_()
+    job.runner = TriggerRunner_()
+    # The job loop runs in its own thread; this (main) thread only observes and shuts down.
+    thread = Thread(target=job._execute)
+    thread.start()
+    try:
+        for _ in range(40):
+            time.sleep(0.1)
+            # ready to evaluate after 2 loops
+            if getattr(job.runner, 'loop_count', 0) >= 2:
+                break
+        else:
+            pytest.fail("did not observe 2 loops in the runner thread")
+    finally:
+        job.runner.stop = True
+        job.runner.join()
+        thread.join()
+    instances = path.read_text().splitlines()
+    assert len(instances) == 1
+
+
+@pytest.mark.skipif(sys.version_info.minor <= 6 and sys.version_info.major <= 3, reason="No triggerer on 3.6")
def test_trigger_from_dead_triggerer(session):
"""
Checks that the triggerer will correctly claim a Trigger that is assigned to a