You are viewing a plain-text version of this content. A canonical version exists, but its link was not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by ep...@apache.org on 2022/10/18 13:10:45 UTC
[airflow] 29/41: Clean-ups around task-mapping code (#26879)
This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch v2-4-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 05408cbe96a5793f59afdb44ebbecf20781ec239
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Tue Oct 11 09:32:01 2022 +0800
Clean-ups around task-mapping code (#26879)
(cherry picked from commit a2d872481e4075669eca35849431139af75b8c07)
---
airflow/models/dagrun.py | 54 ++++++++++++++++------------------------
airflow/models/mappedoperator.py | 10 +++++++-
tests/test_utils/mapping.py | 2 +-
3 files changed, 32 insertions(+), 34 deletions(-)
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 99d969ab19..299fc4b8cc 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -759,7 +759,7 @@ class DagRun(Base, LoggingMixin):
expansion_happened = True
if schedulable.state in SCHEDULEABLE_STATES:
task = schedulable.task
- if isinstance(schedulable.task, MappedOperator):
+ if isinstance(task, MappedOperator):
# Ensure the task indexes are complete
created = self._revise_mapped_task_indexes(task, session=session)
ready_tis.extend(created)
@@ -872,8 +872,6 @@ class DagRun(Base, LoggingMixin):
hook_is_noop: Literal[True, False] = getattr(task_instance_mutation_hook, 'is_noop', False)
dag = self.get_dag()
- task_ids: set[str] = set()
-
task_ids = self._check_for_removed_or_restored_tasks(
dag, task_instance_mutation_hook, session=session
)
@@ -951,8 +949,8 @@ class DagRun(Base, LoggingMixin):
ti.state = State.REMOVED
else:
# What if it is _now_ dynamically mapped, but wasn't before?
- task.run_time_mapped_ti_count.cache_clear() # type: ignore[attr-defined]
- total_length = task.run_time_mapped_ti_count(self.run_id, session=session)
+ task.get_mapped_ti_count.cache_clear() # type: ignore[attr-defined]
+ total_length = task.get_mapped_ti_count(self.run_id, session=session)
if total_length is None:
# Not all upstreams finished, so we can't tell what should be here. Remove everything.
@@ -1045,19 +1043,13 @@ class DagRun(Base, LoggingMixin):
:param session: the session to use
"""
- def expand_mapped_literals(
- task: Operator, sequence: Sequence[int] | None = None
- ) -> tuple[Operator, Sequence[int]]:
+ def expand_mapped_literals(task: Operator) -> tuple[Operator, Sequence[int]]:
if not task.is_mapped:
return (task, (-1,))
task = cast("MappedOperator", task)
- count = task.parse_time_mapped_ti_count or task.run_time_mapped_ti_count(
- self.run_id, session=session
- )
+ count = task.get_mapped_ti_count(self.run_id, session=session)
if not count:
return (task, (-1,))
- if sequence:
- return (task, sequence)
return (task, range(count))
tasks_and_map_idxs = map(expand_mapped_literals, filter(task_filter, dag.task_dict.values()))
@@ -1110,21 +1102,19 @@ class DagRun(Base, LoggingMixin):
# TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
session.rollback()
- def _revise_mapped_task_indexes(self, task, session: Session):
+ def _revise_mapped_task_indexes(self, task: MappedOperator, session: Session) -> Iterable[TI]:
"""Check if task increased or reduced in length and handle appropriately"""
- from airflow.models.taskinstance import TaskInstance
from airflow.settings import task_instance_mutation_hook
- task.run_time_mapped_ti_count.cache_clear()
- total_length = (
- task.parse_time_mapped_ti_count
- or task.run_time_mapped_ti_count(self.run_id, session=session)
- or 0
- )
- query = session.query(TaskInstance.map_index).filter(
- TaskInstance.dag_id == self.dag_id,
- TaskInstance.task_id == task.task_id,
- TaskInstance.run_id == self.run_id,
+ task.get_mapped_ti_count.cache_clear() # type: ignore[attr-defined]
+ total_length = task.get_mapped_ti_count(self.run_id, session=session)
+ if total_length is None: # Upstreams not ready, don't need to revise this yet.
+ return []
+
+ query = session.query(TI.map_index).filter(
+ TI.dag_id == self.dag_id,
+ TI.task_id == task.task_id,
+ TI.run_id == self.run_id,
)
existing_indexes = {i for (i,) in query}
missing_indexes = set(range(total_length)).difference(existing_indexes)
@@ -1133,7 +1123,7 @@ class DagRun(Base, LoggingMixin):
if missing_indexes:
for index in missing_indexes:
- ti = TaskInstance(task, run_id=self.run_id, map_index=index, state=None)
+ ti = TI(task, run_id=self.run_id, map_index=index, state=None)
self.log.debug("Expanding TIs upserted %s", ti)
task_instance_mutation_hook(ti)
ti = session.merge(ti)
@@ -1141,12 +1131,12 @@ class DagRun(Base, LoggingMixin):
session.flush()
created_tis.append(ti)
elif removed_indexes:
- session.query(TaskInstance).filter(
- TaskInstance.dag_id == self.dag_id,
- TaskInstance.task_id == task.task_id,
- TaskInstance.run_id == self.run_id,
- TaskInstance.map_index.in_(removed_indexes),
- ).update({TaskInstance.state: TaskInstanceState.REMOVED})
+ session.query(TI).filter(
+ TI.dag_id == self.dag_id,
+ TI.task_id == task.task_id,
+ TI.run_id == self.run_id,
+ TI.map_index.in_(removed_indexes),
+ ).update({TI.state: TaskInstanceState.REMOVED})
session.flush()
return created_tis
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 8d1c3c4559..62cc22f379 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -727,15 +727,23 @@ class MappedOperator(AbstractOperator):
def parse_time_mapped_ti_count(self) -> int | None:
"""Number of mapped TaskInstances that can be created at DagRun create time.
+ This only considers literal mapped arguments, and would return *None*
+ when any non-literal values are used for mapping.
+
:return: None if non-literal mapped arg encountered, or the total
number of mapped TIs this task should have.
"""
return self._get_specified_expand_input().get_parse_time_mapped_ti_count()
@cache
- def run_time_mapped_ti_count(self, run_id: str, *, session: Session) -> int | None:
+ def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int | None:
"""Number of mapped TaskInstances that can be created at run time.
+ This considers both literal and non-literal mapped arguments, and the
+ result is therefore available when all depended tasks have finished. The
+ return value should be identical to ``parse_time_mapped_ti_count`` if
+ all mapped arguments are literal.
+
:return: None if upstream tasks are not complete yet, or the total
number of mapped TIs this task should have.
"""
diff --git a/tests/test_utils/mapping.py b/tests/test_utils/mapping.py
index 5cfa230369..984446343c 100644
--- a/tests/test_utils/mapping.py
+++ b/tests/test_utils/mapping.py
@@ -42,4 +42,4 @@ def expand_mapped_task(
session.flush()
mapped.expand_mapped_task(run_id, session=session)
- mapped.run_time_mapped_ti_count.cache_clear() # type: ignore[attr-defined]
+ mapped.get_mapped_ti_count.cache_clear() # type: ignore[attr-defined]