You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2022/08/09 14:17:54 UTC
[airflow] branch main updated: Don't mistakenly take a lock on DagRun via ti.refresh_from_db (#25312)
This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new be2b53eaaf Don't mistakenly take a lock on DagRun via ti.refresh_from_db (#25312)
be2b53eaaf is described below
commit be2b53eaaf6fc136db8f3fa3edd797a6c529409a
Author: Ash Berlin-Taylor <as...@apache.org>
AuthorDate: Tue Aug 9 15:17:41 2022 +0100
Don't mistakenly take a lock on DagRun via ti.refresh_from_db (#25312)
In 2.2.0 we made TI.dag_run automatically join-loaded. That is fine for
most cases, but `refresh_from_db` doesn't need it: it doesn't access
anything under ti.dag_run. Worse, when `lock_for_update=True` is passed,
the join means the SELECT ... FOR UPDATE also takes a lock on the DagRun
row, so we lock more than we want to, which _might_ cause deadlocks.
Even if it doesn't, selecting more than we need is wasteful.
---
airflow/models/taskinstance.py | 28 ++++++++++++++++++----------
tests/jobs/test_scheduler_job.py | 4 +++-
2 files changed, 21 insertions(+), 11 deletions(-)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 656d775456..7930d91fb8 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -287,6 +287,7 @@ def clear_task_instances(
if dag_run_state == DagRunState.QUEUED:
dr.last_scheduling_decision = None
dr.start_date = None
+ session.flush()
class _LazyXComAccessIterator(collections.abc.Iterator):
@@ -848,28 +849,35 @@ class TaskInstance(Base, LoggingMixin):
"""
self.log.debug("Refreshing TaskInstance %s from DB", self)
- qry = session.query(TaskInstance).filter(
- TaskInstance.dag_id == self.dag_id,
- TaskInstance.task_id == self.task_id,
- TaskInstance.run_id == self.run_id,
- TaskInstance.map_index == self.map_index,
+ if self in session:
+ session.refresh(self, TaskInstance.__mapper__.column_attrs.keys())
+
+ qry = (
+ # To avoid joining any relationships, by default select all
+ # columns, not the object. This also means we get (effectively) a
+ # namedtuple back, not a TI object
+ session.query(*TaskInstance.__table__.columns).filter(
+ TaskInstance.dag_id == self.dag_id,
+ TaskInstance.task_id == self.task_id,
+ TaskInstance.run_id == self.run_id,
+ TaskInstance.map_index == self.map_index,
+ )
)
if lock_for_update:
for attempt in run_with_db_retries(logger=self.log):
with attempt:
- ti: Optional[TaskInstance] = qry.with_for_update().first()
+ ti: Optional[TaskInstance] = qry.with_for_update().one_or_none()
else:
- ti = qry.first()
+ ti = qry.one_or_none()
if ti:
# Fields ordered per model definition
self.start_date = ti.start_date
self.end_date = ti.end_date
self.duration = ti.duration
self.state = ti.state
- # Get the raw value of try_number column, don't read through the
- # accessor here otherwise it will be incremented by one already.
- self.try_number = ti._try_number
+ # Since we selected columns, not the object, this is the raw value
+ self.try_number = ti.try_number
self.max_tries = ti.max_tries
self.hostname = ti.hostname
self.unixname = ti.unixname
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index f433517511..ba055b295c 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -457,6 +457,7 @@ class TestSchedulerJob:
ti1.state = State.SCHEDULED
self.scheduler_job._critical_section_enqueue_task_instances(session)
+ session.flush()
ti1.refresh_from_db(session=session)
assert State.SCHEDULED == ti1.state
session.rollback()
@@ -1315,7 +1316,8 @@ class TestSchedulerJob:
with patch.object(BaseExecutor, 'queue_command') as mock_queue_command:
self.scheduler_job._enqueue_task_instances_with_queued_state([ti], session=session)
- ti.refresh_from_db()
+ session.flush()
+ ti.refresh_from_db(session=session)
assert ti.state == State.NONE
mock_queue_command.assert_not_called()