You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2022/03/20 06:53:47 UTC

[airflow] 03/25: Fix assignment of unassigned triggers (#21770)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 7c8ce40598792f2aa764ae348498fed91a4efc47
Author: jkramer-ginkgo <68...@users.noreply.github.com>
AuthorDate: Sat Feb 26 14:25:15 2022 -0500

    Fix assignment of unassigned triggers (#21770)
    
    Previously, the query returned no alive triggerers, which resulted
    in all triggers being assigned to the current triggerer. This worked
    fine, despite the logic bug, in the case where there's a single
    triggerer. But with multiple triggerers, concurrent iterations of
    the TriggererJob loop would bounce trigger ownership to whichever
    loop ran last.
    
    Addresses https://github.com/apache/airflow/issues/21616
    
    (cherry picked from commit b26d4d8a290ce0104992ba28850113490c1ca445)
---
 airflow/models/trigger.py    | 10 ++++++---
 tests/models/test_trigger.py | 52 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 59 insertions(+), 3 deletions(-)

diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py
index 5749589..aa0d2b1 100644
--- a/airflow/models/trigger.py
+++ b/airflow/models/trigger.py
@@ -17,7 +17,7 @@
 import datetime
 from typing import Any, Dict, List, Optional
 
-from sqlalchemy import Column, Integer, String, func
+from sqlalchemy import Column, Integer, String, func, or_
 
 from airflow.models.base import Base
 from airflow.models.taskinstance import TaskInstance
@@ -175,7 +175,7 @@ class Trigger(Base):
         alive_triggerer_ids = [
             row[0]
             for row in session.query(BaseJob.id).filter(
-                BaseJob.end_date is None,
+                BaseJob.end_date.is_(None),
                 BaseJob.latest_heartbeat > timezone.utcnow() - datetime.timedelta(seconds=30),
                 BaseJob.job_type == "TriggererJob",
             )
@@ -184,7 +184,11 @@ class Trigger(Base):
         # Find triggers who do NOT have an alive triggerer_id, and then assign
         # up to `capacity` of those to us.
         trigger_ids_query = (
-            session.query(cls.id).filter(cls.triggerer_id.notin_(alive_triggerer_ids)).limit(capacity).all()
+            session.query(cls.id)
+            # notin_ doesn't find NULL rows
+            .filter(or_(cls.triggerer_id.is_(None), cls.triggerer_id.notin_(alive_triggerer_ids)))
+            .limit(capacity)
+            .all()
         )
         session.query(cls).filter(cls.id.in_([i.id for i in trigger_ids_query])).update(
             {cls.triggerer_id: triggerer_id},
diff --git a/tests/models/test_trigger.py b/tests/models/test_trigger.py
index aacfa88..99cd71f 100644
--- a/tests/models/test_trigger.py
+++ b/tests/models/test_trigger.py
@@ -15,8 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import datetime
+
 import pytest
 
+from airflow.jobs.triggerer_job import TriggererJob
 from airflow.models import TaskInstance, Trigger
 from airflow.operators.dummy import DummyOperator
 from airflow.triggers.base import TriggerEvent
@@ -36,9 +39,11 @@ def session():
 def clear_db(session):
     session.query(TaskInstance).delete()
     session.query(Trigger).delete()
+    session.query(TriggererJob).delete()
     yield session
     session.query(TaskInstance).delete()
     session.query(Trigger).delete()
+    session.query(TriggererJob).delete()
     session.commit()
 
 
@@ -124,3 +129,50 @@ def test_submit_failure(session, create_task_instance):
     updated_task_instance = session.query(TaskInstance).one()
     assert updated_task_instance.state == State.SCHEDULED
     assert updated_task_instance.next_method == "__fail__"
+
+
+def test_assign_unassigned(session, create_task_instance):
+    """
+    Tests that unassigned triggers of all appropriate states are assigned.
+    """
+    finished_triggerer = TriggererJob(None, heartrate=10, state=State.SUCCESS)
+    finished_triggerer.end_date = timezone.utcnow() - datetime.timedelta(hours=1)
+    session.add(finished_triggerer)
+    assert not finished_triggerer.is_alive()
+    healthy_triggerer = TriggererJob(None, heartrate=10, state=State.RUNNING)
+    session.add(healthy_triggerer)
+    assert healthy_triggerer.is_alive()
+    new_triggerer = TriggererJob(None, heartrate=10, state=State.RUNNING)
+    session.add(new_triggerer)
+    assert new_triggerer.is_alive()
+    session.commit()
+    trigger_on_healthy_triggerer = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
+    trigger_on_healthy_triggerer.id = 1
+    trigger_on_healthy_triggerer.triggerer_id = healthy_triggerer.id
+    trigger_on_killed_triggerer = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
+    trigger_on_killed_triggerer.id = 2
+    trigger_on_killed_triggerer.triggerer_id = finished_triggerer.id
+    trigger_unassigned_to_triggerer = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
+    trigger_unassigned_to_triggerer.id = 3
+    assert trigger_unassigned_to_triggerer.triggerer_id is None
+    session.add(trigger_on_healthy_triggerer)
+    session.add(trigger_on_killed_triggerer)
+    session.add(trigger_unassigned_to_triggerer)
+    session.commit()
+    assert session.query(Trigger).count() == 3
+    Trigger.assign_unassigned(new_triggerer.id, 100, session=session)
+    session.expire_all()
+    # Check that trigger on killed triggerer and unassigned trigger are assigned to new triggerer
+    assert (
+        session.query(Trigger).filter(Trigger.id == trigger_on_killed_triggerer.id).one().triggerer_id
+        == new_triggerer.id
+    )
+    assert (
+        session.query(Trigger).filter(Trigger.id == trigger_unassigned_to_triggerer.id).one().triggerer_id
+        == new_triggerer.id
+    )
+    # Check that trigger on healthy triggerer still assigned to existing triggerer
+    assert (
+        session.query(Trigger).filter(Trigger.id == trigger_on_healthy_triggerer.id).one().triggerer_id
+        == healthy_triggerer.id
+    )