You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2022/08/09 14:17:54 UTC

[airflow] branch main updated: Don't mistakenly take a lock on DagRun via ti.refresh_from_db (#25312)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new be2b53eaaf Don't mistakenly take a lock on DagRun via ti.refresh_from_db (#25312)
be2b53eaaf is described below

commit be2b53eaaf6fc136db8f3fa3edd797a6c529409a
Author: Ash Berlin-Taylor <as...@apache.org>
AuthorDate: Tue Aug 9 15:17:41 2022 +0100

    Don't mistakenly take a lock on DagRun via ti.refresh_from_db (#25312)
    
    In 2.2.0 we made TI.dag_run be automatically join-loaded, which is fine
    for most cases, but `refresh_from_db` doesn't need it (it doesn't
    access anything under ti.dag_run). Worse, when `lock_for_update=True`
    is passed, the joined DagRun row can end up locked as well -- more than
    we want to lock -- which _might_ cause deadlocks.
    
    Even if it doesn't, selecting more than we need is wasteful.
---
 airflow/models/taskinstance.py   | 28 ++++++++++++++++++----------
 tests/jobs/test_scheduler_job.py |  4 +++-
 2 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 656d775456..7930d91fb8 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -287,6 +287,7 @@ def clear_task_instances(
             if dag_run_state == DagRunState.QUEUED:
                 dr.last_scheduling_decision = None
                 dr.start_date = None
+    session.flush()
 
 
 class _LazyXComAccessIterator(collections.abc.Iterator):
@@ -848,28 +849,35 @@ class TaskInstance(Base, LoggingMixin):
         """
         self.log.debug("Refreshing TaskInstance %s from DB", self)
 
-        qry = session.query(TaskInstance).filter(
-            TaskInstance.dag_id == self.dag_id,
-            TaskInstance.task_id == self.task_id,
-            TaskInstance.run_id == self.run_id,
-            TaskInstance.map_index == self.map_index,
+        if self in session:
+            session.refresh(self, TaskInstance.__mapper__.column_attrs.keys())
+
+        qry = (
+            # To avoid joining any relationships, by default select all
+            # columns, not the object. This also means we get (effectively) a
+            # namedtuple back, not a TI object
+            session.query(*TaskInstance.__table__.columns).filter(
+                TaskInstance.dag_id == self.dag_id,
+                TaskInstance.task_id == self.task_id,
+                TaskInstance.run_id == self.run_id,
+                TaskInstance.map_index == self.map_index,
+            )
         )
 
         if lock_for_update:
             for attempt in run_with_db_retries(logger=self.log):
                 with attempt:
-                    ti: Optional[TaskInstance] = qry.with_for_update().first()
+                    ti: Optional[TaskInstance] = qry.with_for_update().one_or_none()
         else:
-            ti = qry.first()
+            ti = qry.one_or_none()
         if ti:
             # Fields ordered per model definition
             self.start_date = ti.start_date
             self.end_date = ti.end_date
             self.duration = ti.duration
             self.state = ti.state
-            # Get the raw value of try_number column, don't read through the
-            # accessor here otherwise it will be incremented by one already.
-            self.try_number = ti._try_number
+            # Since we selected columns, not the object, this is the raw value
+            self.try_number = ti.try_number
             self.max_tries = ti.max_tries
             self.hostname = ti.hostname
             self.unixname = ti.unixname
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index f433517511..ba055b295c 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -457,6 +457,7 @@ class TestSchedulerJob:
         ti1.state = State.SCHEDULED
 
         self.scheduler_job._critical_section_enqueue_task_instances(session)
+        session.flush()
         ti1.refresh_from_db(session=session)
         assert State.SCHEDULED == ti1.state
         session.rollback()
@@ -1315,7 +1316,8 @@ class TestSchedulerJob:
 
         with patch.object(BaseExecutor, 'queue_command') as mock_queue_command:
             self.scheduler_job._enqueue_task_instances_with_queued_state([ti], session=session)
-        ti.refresh_from_db()
+        session.flush()
+        ti.refresh_from_db(session=session)
         assert ti.state == State.NONE
         mock_queue_command.assert_not_called()