You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2022/10/18 13:10:45 UTC

[airflow] 29/41: Clean-ups around task-mapping code (#26879)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-4-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 05408cbe96a5793f59afdb44ebbecf20781ec239
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Tue Oct 11 09:32:01 2022 +0800

    Clean-ups around task-mapping code (#26879)
    
    (cherry picked from commit a2d872481e4075669eca35849431139af75b8c07)
---
 airflow/models/dagrun.py         | 54 ++++++++++++++++------------------------
 airflow/models/mappedoperator.py | 10 +++++++-
 tests/test_utils/mapping.py      |  2 +-
 3 files changed, 32 insertions(+), 34 deletions(-)

diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 99d969ab19..299fc4b8cc 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -759,7 +759,7 @@ class DagRun(Base, LoggingMixin):
                 expansion_happened = True
             if schedulable.state in SCHEDULEABLE_STATES:
                 task = schedulable.task
-                if isinstance(schedulable.task, MappedOperator):
+                if isinstance(task, MappedOperator):
                     # Ensure the task indexes are complete
                     created = self._revise_mapped_task_indexes(task, session=session)
                     ready_tis.extend(created)
@@ -872,8 +872,6 @@ class DagRun(Base, LoggingMixin):
         hook_is_noop: Literal[True, False] = getattr(task_instance_mutation_hook, 'is_noop', False)
 
         dag = self.get_dag()
-        task_ids: set[str] = set()
-
         task_ids = self._check_for_removed_or_restored_tasks(
             dag, task_instance_mutation_hook, session=session
         )
@@ -951,8 +949,8 @@ class DagRun(Base, LoggingMixin):
                     ti.state = State.REMOVED
             else:
                 #  What if it is _now_ dynamically mapped, but wasn't before?
-                task.run_time_mapped_ti_count.cache_clear()  # type: ignore[attr-defined]
-                total_length = task.run_time_mapped_ti_count(self.run_id, session=session)
+                task.get_mapped_ti_count.cache_clear()  # type: ignore[attr-defined]
+                total_length = task.get_mapped_ti_count(self.run_id, session=session)
 
                 if total_length is None:
                     # Not all upstreams finished, so we can't tell what should be here. Remove everything.
@@ -1045,19 +1043,13 @@ class DagRun(Base, LoggingMixin):
         :param session: the session to use
         """
 
-        def expand_mapped_literals(
-            task: Operator, sequence: Sequence[int] | None = None
-        ) -> tuple[Operator, Sequence[int]]:
+        def expand_mapped_literals(task: Operator) -> tuple[Operator, Sequence[int]]:
             if not task.is_mapped:
                 return (task, (-1,))
             task = cast("MappedOperator", task)
-            count = task.parse_time_mapped_ti_count or task.run_time_mapped_ti_count(
-                self.run_id, session=session
-            )
+            count = task.get_mapped_ti_count(self.run_id, session=session)
             if not count:
                 return (task, (-1,))
-            if sequence:
-                return (task, sequence)
             return (task, range(count))
 
         tasks_and_map_idxs = map(expand_mapped_literals, filter(task_filter, dag.task_dict.values()))
@@ -1110,21 +1102,19 @@ class DagRun(Base, LoggingMixin):
             # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
             session.rollback()
 
-    def _revise_mapped_task_indexes(self, task, session: Session):
+    def _revise_mapped_task_indexes(self, task: MappedOperator, session: Session) -> Iterable[TI]:
         """Check if task increased or reduced in length and handle appropriately"""
-        from airflow.models.taskinstance import TaskInstance
         from airflow.settings import task_instance_mutation_hook
 
-        task.run_time_mapped_ti_count.cache_clear()
-        total_length = (
-            task.parse_time_mapped_ti_count
-            or task.run_time_mapped_ti_count(self.run_id, session=session)
-            or 0
-        )
-        query = session.query(TaskInstance.map_index).filter(
-            TaskInstance.dag_id == self.dag_id,
-            TaskInstance.task_id == task.task_id,
-            TaskInstance.run_id == self.run_id,
+        task.get_mapped_ti_count.cache_clear()  # type: ignore[attr-defined]
+        total_length = task.get_mapped_ti_count(self.run_id, session=session)
+        if total_length is None:  # Upstreams not ready, don't need to revise this yet.
+            return []
+
+        query = session.query(TI.map_index).filter(
+            TI.dag_id == self.dag_id,
+            TI.task_id == task.task_id,
+            TI.run_id == self.run_id,
         )
         existing_indexes = {i for (i,) in query}
         missing_indexes = set(range(total_length)).difference(existing_indexes)
@@ -1133,7 +1123,7 @@ class DagRun(Base, LoggingMixin):
 
         if missing_indexes:
             for index in missing_indexes:
-                ti = TaskInstance(task, run_id=self.run_id, map_index=index, state=None)
+                ti = TI(task, run_id=self.run_id, map_index=index, state=None)
                 self.log.debug("Expanding TIs upserted %s", ti)
                 task_instance_mutation_hook(ti)
                 ti = session.merge(ti)
@@ -1141,12 +1131,12 @@ class DagRun(Base, LoggingMixin):
                 session.flush()
                 created_tis.append(ti)
         elif removed_indexes:
-            session.query(TaskInstance).filter(
-                TaskInstance.dag_id == self.dag_id,
-                TaskInstance.task_id == task.task_id,
-                TaskInstance.run_id == self.run_id,
-                TaskInstance.map_index.in_(removed_indexes),
-            ).update({TaskInstance.state: TaskInstanceState.REMOVED})
+            session.query(TI).filter(
+                TI.dag_id == self.dag_id,
+                TI.task_id == task.task_id,
+                TI.run_id == self.run_id,
+                TI.map_index.in_(removed_indexes),
+            ).update({TI.state: TaskInstanceState.REMOVED})
             session.flush()
         return created_tis
 
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 8d1c3c4559..62cc22f379 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -727,15 +727,23 @@ class MappedOperator(AbstractOperator):
     def parse_time_mapped_ti_count(self) -> int | None:
         """Number of mapped TaskInstances that can be created at DagRun create time.
 
+        This only considers literal mapped arguments, and would return *None*
+        when any non-literal values are used for mapping.
+
         :return: None if non-literal mapped arg encountered, or the total
             number of mapped TIs this task should have.
         """
         return self._get_specified_expand_input().get_parse_time_mapped_ti_count()
 
     @cache
-    def run_time_mapped_ti_count(self, run_id: str, *, session: Session) -> int | None:
+    def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int | None:
         """Number of mapped TaskInstances that can be created at run time.
 
+        This considers both literal and non-literal mapped arguments, and the
+        result is therefore only available once all upstream tasks have finished.
+        The return value should be identical to ``parse_time_mapped_ti_count`` if
+        all mapped arguments are literal.
+
         :return: None if upstream tasks are not complete yet, or the total
             number of mapped TIs this task should have.
         """
diff --git a/tests/test_utils/mapping.py b/tests/test_utils/mapping.py
index 5cfa230369..984446343c 100644
--- a/tests/test_utils/mapping.py
+++ b/tests/test_utils/mapping.py
@@ -42,4 +42,4 @@ def expand_mapped_task(
     session.flush()
 
     mapped.expand_mapped_task(run_id, session=session)
-    mapped.run_time_mapped_ti_count.cache_clear()  # type: ignore[attr-defined]
+    mapped.get_mapped_ti_count.cache_clear()  # type: ignore[attr-defined]