Posted to commits@airflow.apache.org by ep...@apache.org on 2022/06/18 07:33:03 UTC
[airflow] branch main updated: Fix mapped task immutability after clear (#23667)
This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b692517ce3 Fix mapped task immutability after clear (#23667)
b692517ce3 is described below
commit b692517ce3aafb276e9d23570e9734c30a5f3d1f
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Sat Jun 18 08:32:38 2022 +0100
Fix mapped task immutability after clear (#23667)
We should be able to detect when the structure of a mapped task has changed
and verify the integrity of the dag run accordingly. This PR ensures this by
finding the missing map indexes after a clear and creating task instances for them.
Co-authored-by: Tzu-ping Chung <ur...@gmail.com>
---
airflow/models/dagrun.py | 114 +++++++++++++++++++++++++------
tests/models/test_dagrun.py | 161 +++++++++++++++++++++++++++++++++++++++++++-
2 files changed, 251 insertions(+), 24 deletions(-)
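The heart of the fix is a set difference: compare the map indexes that
already have task instances against the range expected after re-expansion,
and create task instances for whatever is missing. A minimal standalone
sketch of that logic (illustrative only, not part of the commit; the names
stand in for values the DagRun derives via run_time_mapped_ti_count):

    # Per mapped task: expected indexes minus existing indexes = missing.
    existing_indexes = [0, 1, 2]  # map indexes that already have TIs
    new_length = 4                # expansion length recomputed after a clear
    missing = sorted(set(range(new_length)).difference(existing_indexes))
    assert missing == [3]         # a new TI must be created for map_index 3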
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 216272c79a..05713b96c7 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -642,15 +642,9 @@ class DagRun(Base, LoggingMixin):
tis = list(self.get_task_instances(session=session, state=State.task_states))
self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis))
dag = self.get_dag()
- for ti in tis:
- try:
- ti.task = dag.get_task(ti.task_id)
- except TaskNotFound:
- self.log.warning(
- "Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, ti.dag_id
- )
- ti.state = State.REMOVED
- session.flush()
+ missing_indexes = self._find_missing_task_indexes(dag, tis, session=session)
+ if missing_indexes:
+ self.verify_integrity(missing_indexes=missing_indexes, session=session)
unfinished_tis = [t for t in tis if t.state in State.unfinished]
finished_tis = [t for t in tis if t.state in State.finished]
@@ -811,11 +805,17 @@ class DagRun(Base, LoggingMixin):
Stats.timing(f'dagrun.duration.failed.{self.dag_id}', duration)
@provide_session
- def verify_integrity(self, session: Session = NEW_SESSION):
+ def verify_integrity(
+ self,
+ *,
+ missing_indexes: Optional[Dict["MappedOperator", Sequence[int]]] = None,
+ session: Session = NEW_SESSION,
+ ):
"""
Verifies the DagRun by checking for removed tasks or tasks that are not in the
database yet. It will set state to removed or add the task if required.
:param missing_indexes: A dictionary mapping each mapped task to the map indexes that are missing.
:param session: Sqlalchemy ORM Session
"""
from airflow.settings import task_instance_mutation_hook
@@ -824,9 +824,16 @@ class DagRun(Base, LoggingMixin):
hook_is_noop = getattr(task_instance_mutation_hook, 'is_noop', False)
dag = self.get_dag()
- task_ids = self._check_for_removed_or_restored_tasks(
- dag, task_instance_mutation_hook, session=session
- )
+ task_ids: Set[str] = set()
+ if missing_indexes:
+ tis = self.get_task_instances(session=session)
+ for ti in tis:
+ task_instance_mutation_hook(ti)
+ task_ids.add(ti.task_id)
+ else:
+ task_ids, missing_indexes = self._check_for_removed_or_restored_tasks(
+ dag, task_instance_mutation_hook, session=session
+ )
def task_filter(task: "Operator") -> bool:
return task.task_id not in task_ids and (
@@ -841,27 +848,29 @@ class DagRun(Base, LoggingMixin):
task_creator = self._get_task_creator(created_counts, task_instance_mutation_hook, hook_is_noop)
# Create the missing tasks, including mapped tasks
- tasks = self._create_missing_tasks(dag, task_creator, task_filter, session=session)
+ tasks = self._create_missing_tasks(dag, task_creator, task_filter, missing_indexes, session=session)
self._create_task_instances(dag.dag_id, tasks, created_counts, hook_is_noop, session=session)
def _check_for_removed_or_restored_tasks(
self, dag: "DAG", ti_mutation_hook, *, session: Session
- ) -> Set[str]:
+ ) -> Tuple[Set[str], Dict["MappedOperator", Sequence[int]]]:
"""
- Check for removed tasks/restored tasks.
Check for removed, restored, or missing tasks.
:param dag: DAG object corresponding to the dagrun
:param ti_mutation_hook: task_instance_mutation_hook function
:param session: Sqlalchemy ORM Session
- :return: List of task_ids in the dagrun
+ :return: Set of task_ids in the dagrun and a dict of missing map indexes per mapped task
"""
tis = self.get_task_instances(session=session)
# check for removed or restored tasks
task_ids = set()
+ existing_indexes: Dict["MappedOperator", List[int]] = defaultdict(list)
+ expected_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
for ti in tis:
ti_mutation_hook(ti)
task_ids.add(ti.task_id)
@@ -902,7 +911,8 @@ class DagRun(Base, LoggingMixin):
else:
self.log.info("Restoring mapped task '%s'", ti)
Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
- ti.state = State.NONE
+ existing_indexes[task].append(ti.map_index)
+ expected_indexes[task] = range(num_mapped_tis)
else:
# What if it is _now_ dynamically mapped, but wasn't before?
total_length = task.run_time_mapped_ti_count(self.run_id, session=session)
@@ -923,8 +933,16 @@ class DagRun(Base, LoggingMixin):
total_length,
)
ti.state = State.REMOVED
- ...
- return task_ids
+ else:
+ self.log.info("Restoring mapped task '%s'", ti)
+ Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
+ existing_indexes[task].append(ti.map_index)
+ expected_indexes[task] = range(total_length)
+ # Check if there are missing map indexes for which TIs need to be created
+ missing_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
+ for k, v in existing_indexes.items():
+ missing_indexes.update({k: list(set(expected_indexes[k]).difference(v))})
+ return task_ids, missing_indexes
def _get_task_creator(
self, created_counts: Dict[str, int], ti_mutation_hook: Callable, hook_is_noop: bool
@@ -961,7 +979,13 @@ class DagRun(Base, LoggingMixin):
return creator
def _create_missing_tasks(
- self, dag: "DAG", task_creator: Callable, task_filter: Callable, *, session: Session
+ self,
+ dag: "DAG",
+ task_creator: Callable,
+ task_filter: Callable,
+ missing_indexes: Optional[Dict["MappedOperator", Sequence[int]]],
+ *,
+ session: Session,
) -> Iterable["Operator"]:
"""
Create missing tasks -- and expand any MappedOperator that _only_ has literals as input
@@ -972,7 +996,9 @@ class DagRun(Base, LoggingMixin):
:param session: the session to use
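:param missing_indexes: missing map indexes per mapped task, for which TIs should be created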
"""
- def expand_mapped_literals(task: "Operator") -> Tuple["Operator", Sequence[int]]:
+ def expand_mapped_literals(
+ task: "Operator", sequence: Union[Sequence[int], None] = None
+ ) -> Tuple["Operator", Sequence[int]]:
if not task.is_mapped:
return (task, (-1,))
task = cast("MappedOperator", task)
@@ -981,11 +1007,19 @@ class DagRun(Base, LoggingMixin):
)
if not count:
return (task, (-1,))
+ if sequence:
+ return (task, sequence)
return (task, range(count))
tasks_and_map_idxs = map(expand_mapped_literals, filter(task_filter, dag.task_dict.values()))
tasks = itertools.chain.from_iterable(itertools.starmap(task_creator, tasks_and_map_idxs))
+ if missing_indexes:
+ # If there are missing indexes, override the tasks to create
+ new_tasks_and_map_idxs = itertools.starmap(
+ expand_mapped_literals, [(k, v) for k, v in missing_indexes.items() if len(v) > 0]
+ )
+ tasks = itertools.chain.from_iterable(itertools.starmap(task_creator, new_tasks_and_map_idxs))
return tasks
def _create_task_instances(
@@ -1027,6 +1061,42 @@ class DagRun(Base, LoggingMixin):
# TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
session.rollback()
+ def _find_missing_task_indexes(self, dag, tis, *, session) -> Dict["MappedOperator", Sequence[int]]:
+ """
+ Check whether the number of mapped task instances changed at runtime.
+ If so, find the map indexes that are missing.
+
+ This function also marks as REMOVED any task instance whose task is no longer present in the dag.
+
+ :param dag: DAG object corresponding to the dagrun
+ :param tis: task instances to check
+ :param session: the session to use
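+ :return: dict mapping each mapped task to the map indexes missing from its existing task instances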
+ """
+ existing_indexes: Dict["MappedOperator", list] = defaultdict(list)
+ new_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
+ for ti in tis:
+ try:
+ task = ti.task = dag.get_task(ti.task_id)
+ except TaskNotFound:
+ self.log.error("Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, ti.dag_id)
+
+ ti.state = State.REMOVED
+ session.flush()
+ continue
+ if not task.is_mapped:
+ continue
+ # Skip unexpanded tasks and tasks that expand with literal arguments
+ if ti.map_index < 0 or task.parse_time_mapped_ti_count:
+ continue
+ existing_indexes[task].append(ti.map_index)
+ task.run_time_mapped_ti_count.cache_clear()
+ new_length = task.run_time_mapped_ti_count(self.run_id, session=session) or 0
+ new_indexes[task] = range(new_length)
+ missing_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
+ for k, v in existing_indexes.items():
+ missing_indexes.update({k: list(set(new_indexes[k]).difference(v))})
+ return missing_indexes
+
@staticmethod
def get_run(session: Session, dag_id: str, execution_date: datetime) -> Optional['DagRun']:
"""
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 14f4b7f34b..6c3cc1c91c 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -41,7 +41,7 @@ from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import DagRunType
from tests.models import DEFAULT_DATE as _DEFAULT_DATE
-from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs
+from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs, clear_db_variables
from tests.test_utils.mock_operators import MockOperator
DEFAULT_DATE = pendulum.instance(_DEFAULT_DATE)
@@ -54,11 +54,13 @@ class TestDagRun:
clear_db_runs()
clear_db_pools()
clear_db_dags()
+ clear_db_variables()
def teardown_method(self) -> None:
clear_db_runs()
clear_db_pools()
clear_db_dags()
+ clear_db_variables()
def create_dag_run(
self,
@@ -922,7 +924,7 @@ def test_verify_integrity_task_start_and_end_date(Stats_incr, session, run_type,
session.add(dag_run)
session.flush()
- dag_run.verify_integrity(session)
+ dag_run.verify_integrity(session=session)
tis = dag_run.task_instances
assert len(tis) == expected_tis
@@ -1050,6 +1052,161 @@ def test_mapped_literal_to_xcom_arg_verify_integrity(dag_maker, session):
]
+def test_mapped_literal_length_increase_adds_additional_ti(dag_maker, session):
+ """Test that when the length of mapped literal increases, additional ti is added"""
+
+ with dag_maker(session=session) as dag:
+
+ @task
+ def task_2(arg2):
+ ...
+
+ task_2.expand(arg2=[1, 2, 3, 4])
+
+ dr = dag_maker.create_dagrun()
+ tis = dr.get_task_instances()
+ indices = [(ti.map_index, ti.state) for ti in tis]
+ assert sorted(indices) == [
+ (0, State.NONE),
+ (1, State.NONE),
+ (2, State.NONE),
+ (3, State.NONE),
+ ]
+
+ # Now "increase" the length of literal
+ dag._remove_task('task_2')
+
+ with dag:
+ task_2.expand(arg2=[1, 2, 3, 4, 5]).operator
+
+ # At this point, we need to test that the change works on the serialized
+ # DAG (which is what the scheduler operates on)
+ serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+ dr.dag = serialized_dag
+ # Since we changed the literal in the dag file itself, the dag_hash will
+ # change, which will prompt the scheduler to verify the dag run's integrity
+ dr.verify_integrity()
+
+ tis = dr.get_task_instances()
+ indices = [(ti.map_index, ti.state) for ti in tis]
+ assert sorted(indices) == [
+ (0, State.NONE),
+ (1, State.NONE),
+ (2, State.NONE),
+ (3, State.NONE),
+ (4, State.NONE),
+ ]
+
+
+def test_mapped_literal_length_reduction_adds_removed_state(dag_maker, session):
+ """Test that when the length of mapped literal reduces, removed state is added"""
+
+ with dag_maker(session=session) as dag:
+
+ @task
+ def task_2(arg2):
+ ...
+
+ task_2.expand(arg2=[1, 2, 3, 4])
+
+ dr = dag_maker.create_dagrun()
+ tis = dr.get_task_instances()
+ indices = [(ti.map_index, ti.state) for ti in tis]
+ assert sorted(indices) == [
+ (0, State.NONE),
+ (1, State.NONE),
+ (2, State.NONE),
+ (3, State.NONE),
+ ]
+
+ # Now "reduce" the length of literal
+ dag._remove_task('task_2')
+
+ with dag:
+ task_2.expand(arg2=[1, 2]).operator
+
+ # At this point, we need to test that the change works on the serialized
+ # DAG (which is what the scheduler operates on)
+ serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+ dr.dag = serialized_dag
+ # Since we changed the literal in the dag file itself, the dag_hash will
+ # change, which will prompt the scheduler to verify the dag run's integrity
+ dr.verify_integrity()
+
+ tis = dr.get_task_instances()
+ indices = [(ti.map_index, ti.state) for ti in tis]
+ assert sorted(indices) == [
+ (0, State.NONE),
+ (1, State.NONE),
+ (2, State.REMOVED),
+ (3, State.REMOVED),
+ ]
+
+
+def test_mapped_literal_length_increase_at_runtime_adds_additional_tis(dag_maker, session):
+ """Test that when the length of mapped literal increases at runtime, additional ti is added"""
+ from airflow.models import Variable
+
+ Variable.set(key='arg1', value=[1, 2, 3])
+
+ @task
+ def task_1():
+ return Variable.get('arg1', deserialize_json=True)
+
+ with dag_maker(session=session) as dag:
+
+ @task
+ def task_2(arg2):
+ ...
+
+ task_2.expand(arg2=task_1())
+
+ dr = dag_maker.create_dagrun()
+ ti = dr.get_task_instance(task_id='task_1')
+ ti.run()
+ dr.task_instance_scheduling_decisions()
+ tis = dr.get_task_instances()
+ indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
+ assert sorted(indices) == [
+ (0, State.NONE),
+ (1, State.NONE),
+ (2, State.NONE),
+ ]
+
+ # Now "clear" and "increase" the length of literal
+ dag.clear()
+ Variable.set(key='arg1', value=[1, 2, 3, 4])
+
+ with dag:
+ task_2.expand(arg2=task_1()).operator
+
+ # At this point, we need to test that the change works on the serialized
+ # DAG (which is what the scheduler operates on)
+ serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+ dr.dag = serialized_dag
+
+ # Run the first task again to get the new lengths
+ ti = dr.get_task_instance(task_id='task_1')
+ task1 = dag.get_task('task_1')
+ ti.refresh_from_task(task1)
+ ti.run()
+
+ # This would be called by the LocalTaskJob
+ dr.task_instance_scheduling_decisions()
+ tis = dr.get_task_instances()
+
+ indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
+ assert sorted(indices) == [
+ (0, State.NONE),
+ (1, State.NONE),
+ (2, State.NONE),
+ (3, State.NONE),
+ ]
+
+
@pytest.mark.need_serialized_dag
def test_mapped_mixed__literal_not_expanded_at_create(dag_maker, session):
literal = [1, 2, 3, 4]