You are viewing a plain-text version of this content. The canonical link to it was a hyperlink on the original page ("here") and is not preserved in this plain-text extraction.
Posted to commits@airflow.apache.org by bb...@apache.org on 2022/04/20 19:00:47 UTC
[airflow] 05/19: fixup! fixup! fixup! fixup! Allow marking/clearing mapped taskinstances from the UI
This is an automated email from the ASF dual-hosted git repository.
bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit c5cc48f9f9161b187368d99b551c652e1da03de5
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Wed Apr 13 20:16:56 2022 +0100
fixup! fixup! fixup! fixup! Allow marking/clearing mapped taskinstances from the UI
---
airflow/api/common/mark_tasks.py | 38 ++++-----
airflow/models/dag.py | 159 ++++++++++++++++++++----------------
airflow/www/views.py | 44 ++++++----
tests/api/common/test_mark_tasks.py | 6 +-
tests/models/test_dag.py | 2 +-
5 files changed, 133 insertions(+), 116 deletions(-)
diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index 1d4709fb82..84fd48f4e4 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -18,9 +18,9 @@
"""Marks tasks APIs."""
from datetime import datetime
-from typing import TYPE_CHECKING, Collection, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union
-from sqlalchemy import or_
+from sqlalchemy import or_, tuple_
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm.session import Session as SASession
@@ -32,7 +32,6 @@ from airflow.operators.subdag import SubDagOperator
from airflow.utils import timezone
from airflow.utils.helpers import exactly_one
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import tuple_in_condition
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import DagRunType
@@ -79,7 +78,7 @@ def _create_dagruns(
@provide_session
def set_state(
*,
- tasks: Union[Collection[Operator], Collection[Tuple[Operator, int]]],
+ tasks: Union[Iterable[Operator], Iterable[Tuple[Operator, int]]],
run_id: Optional[str] = None,
execution_date: Optional[datetime] = None,
upstream: bool = False,
@@ -97,7 +96,7 @@ def set_state(
tasks that did not exist. It will not create dag runs that are missing
on the schedule (but it will as for subdag dag runs if needed).
- :param tasks: the iterable of tasks or (task, map_index) tuples from which to work.
+ :param tasks: the iterable of tasks or task, map_index tuple from which to work.
task.task.dag needs to be set
:param run_id: the run_id of the dagrun to start looking from
:param execution_date: the execution date from which to start looking(deprecated)
@@ -120,7 +119,9 @@ def set_state(
if execution_date and not timezone.is_localized(execution_date):
raise ValueError(f"Received non-localized date {execution_date}")
- task_dags = {task[0].dag if isinstance(task, tuple) else task.dag for task in tasks}
+ t_dags = {task.dag for task in tasks if not isinstance(task, tuple)}
+ t_dags_2 = {item[0].dag for item in tasks if isinstance(item, tuple)}
+ task_dags = t_dags | t_dags_2
if len(task_dags) > 1:
raise ValueError(f"Received tasks from multiple DAGs: {task_dags}")
dag = next(iter(task_dags))
@@ -135,12 +136,6 @@ def set_state(
dag_run_ids = get_run_ids(dag, run_id, future, past)
task_id_map_index_list = list(find_task_relatives(tasks, downstream, upstream))
task_ids = [task_id for task_id, _ in task_id_map_index_list]
- # check if task_id_map_index_list contains map_index of None
- # if it contains None, there was no map_index supplied for the task
- for _, index in task_id_map_index_list:
- if index is None:
- task_id_map_index_list = [task_id for task_id, _ in task_id_map_index_list]
- break
confirmed_infos = list(_iter_existing_dag_run_infos(dag, dag_run_ids))
confirmed_dates = [info.logical_date for info in confirmed_infos]
@@ -187,26 +182,20 @@ def get_all_dag_task_query(
dag: DAG,
session: SASession,
state: TaskInstanceState,
- task_ids: Union[List[str], List[Tuple[str, int]]],
+ task_id_map_index_list: List[Tuple[str, int]],
confirmed_dates: Iterable[datetime],
):
"""Get all tasks of the main dag that will be affected by a state change"""
- is_string_list = isinstance(task_ids[0], str)
qry_dag = (
session.query(TaskInstance)
.join(TaskInstance.dag_run)
.filter(
TaskInstance.dag_id == dag.dag_id,
DagRun.execution_date.in_(confirmed_dates),
+ tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_id_map_index_list),
)
- )
-
- if is_string_list:
- qry_dag = qry_dag.filter(TaskInstance.task_id.in_(task_ids))
- else:
- qry_dag = qry_dag.filter(tuple_in_condition((TaskInstance.task_id, TaskInstance.map_index), task_ids))
- qry_dag = qry_dag.filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state)).options(
- contains_eager(TaskInstance.dag_run)
+ .filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
+ .options(contains_eager(TaskInstance.dag_run))
)
return qry_dag
@@ -282,13 +271,14 @@ def _iter_existing_dag_run_infos(dag: DAG, run_ids: List[str]) -> Iterator[_DagR
yield _DagRunInfo(dag_run.logical_date, dag.get_run_data_interval(dag_run))
-def find_task_relatives(tasks, downstream, upstream):
+@provide_session
+def find_task_relatives(tasks, downstream, upstream, session: SASession = NEW_SESSION):
"""Yield task ids and optionally ancestor and descendant ids."""
for item in tasks:
if isinstance(item, tuple):
task, map_index = item
else:
- task, map_index = item, None
+ task, map_index = item, -1
yield task.task_id, map_index
if downstream:
for relative in task.get_flat_relatives(upstream=False):
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index e9c33acb72..931fd469d7 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -39,7 +39,6 @@ from typing import (
Iterable,
List,
Optional,
- Sequence,
Set,
Tuple,
Type,
@@ -52,7 +51,7 @@ import jinja2
import pendulum
from dateutil.relativedelta import relativedelta
from pendulum.tz.timezone import Timezone
-from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, not_, or_
+from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, or_, tuple_
from sqlalchemy.orm import backref, joinedload, relationship
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session
@@ -85,7 +84,7 @@ from airflow.utils.file import correct_maybe_zipped
from airflow.utils.helpers import exactly_one, validate_key
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, tuple_in_condition, with_row_locks
+from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, with_row_locks
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType
@@ -1341,33 +1340,47 @@ class DAG(LoggingMixin):
start_date = (timezone.utcnow() - timedelta(30)).replace(
hour=0, minute=0, second=0, microsecond=0
)
- query = self._get_task_instances(
- task_ids=None,
- start_date=start_date,
- end_date=end_date,
- run_id=None,
- state=state or (),
- include_subdags=False,
- include_parentdag=False,
- include_dependent_dags=False,
- exclude_task_ids=(),
- session=session,
+
+ if state is None:
+ state = []
+
+ return (
+ cast(
+ Query,
+ self._get_task_instances(
+ task_ids=None,
+ task_ids_and_map_indexes=None,
+ start_date=start_date,
+ end_date=end_date,
+ run_id=None,
+ state=state,
+ include_subdags=False,
+ include_parentdag=False,
+ include_dependent_dags=False,
+ exclude_task_ids=cast(List[str], []),
+ exclude_task_ids_and_map_indexes=None,
+ session=session,
+ ),
+ )
+ .order_by(DagRun.execution_date)
+ .all()
)
- return cast(Query, query).order_by(DagRun.execution_date).all()
@overload
def _get_task_instances(
self,
*,
- task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
+ task_ids: Iterable[str],
+ task_ids_and_map_indexes: Optional[Iterable[Tuple[str, int]]],
start_date: Optional[datetime],
end_date: Optional[datetime],
run_id: Optional[str],
- state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
+ state: Union[TaskInstanceState, List[TaskInstanceState]],
include_subdags: bool,
include_parentdag: bool,
include_dependent_dags: bool,
- exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
+ exclude_task_ids: Collection[str],
+ exclude_task_ids_and_map_indexes: Collection[Tuple[str, int]],
session: Session,
dag_bag: Optional["DagBag"] = ...,
) -> Iterable[TaskInstance]:
@@ -1377,16 +1390,18 @@ class DAG(LoggingMixin):
def _get_task_instances(
self,
*,
- task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
+ task_ids: Iterable[str],
+ task_ids_and_map_indexes: Optional[Iterable[Tuple[str, int]]],
as_pk_tuple: Literal[True],
start_date: Optional[datetime],
end_date: Optional[datetime],
run_id: Optional[str],
- state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
+ state: Union[TaskInstanceState, List[TaskInstanceState]],
include_subdags: bool,
include_parentdag: bool,
include_dependent_dags: bool,
- exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
+ exclude_task_ids: Collection[str],
+ exclude_task_ids_and_map_indexes: Collection[Tuple[str, int]],
session: Session,
dag_bag: Optional["DagBag"] = ...,
recursion_depth: int = ...,
@@ -1398,16 +1413,18 @@ class DAG(LoggingMixin):
def _get_task_instances(
self,
*,
- task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
+ task_ids: Iterable[str],
+ task_ids_and_map_indexes: Optional[Iterable[Tuple[str, int]]],
as_pk_tuple: Literal[True, None] = None,
start_date: Optional[datetime],
end_date: Optional[datetime],
run_id: Optional[str],
- state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
+ state: Union[TaskInstanceState, List[TaskInstanceState]],
include_subdags: bool,
include_parentdag: bool,
include_dependent_dags: bool,
- exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
+ exclude_task_ids: Collection[str],
+ exclude_task_ids_and_map_indexes: Collection[Tuple[str, int]],
session: Session,
dag_bag: Optional["DagBag"] = None,
recursion_depth: int = 0,
@@ -1431,6 +1448,11 @@ class DAG(LoggingMixin):
tis = session.query(TaskInstance)
tis = tis.join(TaskInstance.dag_run)
+ if task_ids is not None: # task not mapped
+ task_ids_and_map_indexes = [(task_id, -1) for task_id in task_ids]
+ if exclude_task_ids and len(exclude_task_ids) > 0: # task not mapped
+ exclude_task_ids_and_map_indexes = [(task_id, -1) for task_id in task_ids]
+
if include_subdags:
# Crafting the right filter for dag_id and task_ids combo
conditions = []
@@ -1445,13 +1467,10 @@ class DAG(LoggingMixin):
tis = tis.filter(TaskInstance.run_id == run_id)
if start_date:
tis = tis.filter(DagRun.execution_date >= start_date)
-
- if task_ids is None:
- pass # Disable filter if not set.
- elif isinstance(next(iter(task_ids), None), str):
- tis = tis.filter(TI.task_id.in_(task_ids))
- else:
- tis = tis.filter(tuple_in_condition((TI.task_id, TI.map_index), task_ids))
+ if task_ids_and_map_indexes:
+ tis = tis.filter(
+ tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_ids_and_map_indexes)
+ )
# This allows allow_trigger_in_future config to take affect, rather than mandating exec_date <= UTC
if end_date or not self.allow_future_exec_dates:
@@ -1490,6 +1509,7 @@ class DAG(LoggingMixin):
result.update(
p_dag._get_task_instances(
task_ids=task_ids,
+ task_ids_and_map_indexes=task_ids_and_map_indexes,
start_date=start_date,
end_date=end_date,
run_id=None,
@@ -1499,6 +1519,7 @@ class DAG(LoggingMixin):
include_dependent_dags=include_dependent_dags,
as_pk_tuple=True,
exclude_task_ids=exclude_task_ids,
+ exclude_task_ids_and_map_indexes=exclude_task_ids_and_map_indexes,
session=session,
dag_bag=dag_bag,
recursion_depth=recursion_depth,
@@ -1566,7 +1587,7 @@ class DAG(LoggingMixin):
)
result.update(
downstream._get_task_instances(
- task_ids=None,
+ task_ids_and_map_indexes=None,
run_id=tii.run_id,
start_date=None,
end_date=None,
@@ -1575,7 +1596,7 @@ class DAG(LoggingMixin):
include_dependent_dags=include_dependent_dags,
include_parentdag=False,
as_pk_tuple=True,
- exclude_task_ids=exclude_task_ids,
+ exclude_task_ids_and_map_indexes=exclude_task_ids_and_map_indexes,
dag_bag=dag_bag,
session=session,
recursion_depth=recursion_depth + 1,
@@ -1589,29 +1610,25 @@ class DAG(LoggingMixin):
if as_pk_tuple:
result.update(TaskInstanceKey(*cols) for cols in tis.all())
else:
- result.update(ti.key for ti in tis)
-
- if exclude_task_ids is not None:
- result = {
- task
- for task in result
- if task.task_id not in exclude_task_ids
- and (task.task_id, task.map_index) not in exclude_task_ids
- }
+ result.update(ti.key for ti in tis.all())
+
+ if exclude_task_ids_and_map_indexes:
+ result = set(
+ filter(
+ lambda key: (key.task_id, key.map_index) not in exclude_task_ids_and_map_indexes,
+ result,
+ )
+ )
if as_pk_tuple:
return result
- if result:
+ elif result:
# We've been asked for objects, lets combine it all back in to a result set
- ti_filters = TI.filter_for_tis(result)
- if ti_filters is not None:
- tis = session.query(TI).filter(ti_filters)
- elif exclude_task_ids is None:
- pass # Disable filter if not set.
- elif isinstance(next(iter(exclude_task_ids), None), str):
- tis = tis.filter(TI.task_id.notin_(exclude_task_ids))
- else:
- tis = tis.filter(not_(tuple_in_condition((TI.task_id, TI.map_index), exclude_task_ids)))
+ tis = tis.with_entities(TI.dag_id, TI.task_id, TI.run_id, TI.map_index)
+
+ tis = session.query(TI).filter(TI.filter_for_tis(result))
+ elif exclude_task_ids_and_map_indexes:
+ tis = tis.filter(tuple_(TI.task_id, TI.map_index).notin_(exclude_task_ids_and_map_indexes))
return tis
@@ -1620,7 +1637,7 @@ class DAG(LoggingMixin):
self,
*,
task_id: str,
- map_indexes: Optional[Collection[int]] = None,
+ map_indexes: Optional[Iterable[int]] = None,
execution_date: Optional[datetime] = None,
run_id: Optional[str] = None,
state: TaskInstanceState,
@@ -1636,8 +1653,7 @@ class DAG(LoggingMixin):
in failed or upstream_failed state.
:param task_id: Task ID of the TaskInstance
- :param map_indexes: Only set TaskInstance if its map_index matches.
- If None (default), all mapped TaskInstances of the task are set.
+ :param map_indexes: Task instance map_index to set the state of
:param execution_date: Execution date of the TaskInstance
:param run_id: The run_id of the TaskInstance
:param state: State to set the TaskInstance to
@@ -1662,18 +1678,13 @@ class DAG(LoggingMixin):
task = self.get_task(task_id)
task.dag = self
-
- tasks_to_set_state: Union[List[Operator], List[Tuple[Operator, int]]]
- task_ids_to_exclude_from_clear: Union[Set[str], Set[Tuple[str, int]]]
- if map_indexes is None:
- tasks_to_set_state = [task]
- task_ids_to_exclude_from_clear = {task_id}
- else:
- tasks_to_set_state = [(task, map_index) for map_index in map_indexes]
- task_ids_to_exclude_from_clear = {(task_id, map_index) for map_index in map_indexes}
+ if not map_indexes:
+ map_indexes = [-1]
+ task_map_indexes = [(task, map_index) for map_index in map_indexes]
+ task_id_map_indexes = [(task_id, map_index) for map_index in map_indexes]
altered = set_state(
- tasks=tasks_to_set_state,
+ tasks=task_map_indexes,
execution_date=execution_date,
run_id=run_id,
upstream=upstream,
@@ -1703,13 +1714,12 @@ class DAG(LoggingMixin):
subdag.clear(
start_date=start_date,
end_date=end_date,
- map_indexes=map_indexes,
include_subdags=True,
include_parentdag=True,
only_failed=True,
session=session,
# Exclude the task itself from being cleared
- exclude_task_ids=task_ids_to_exclude_from_clear,
+ exclude_task_ids_and_map_indexes=task_id_map_indexes,
)
return altered
@@ -1767,7 +1777,8 @@ class DAG(LoggingMixin):
@provide_session
def clear(
self,
- task_ids: Union[Collection[str], Collection[Tuple[str, int]], None] = None,
+ task_ids=None,
+ task_ids_and_map_indexes: Optional[Iterable[Tuple[str, int]]] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
only_failed: bool = False,
@@ -1782,13 +1793,14 @@ class DAG(LoggingMixin):
recursion_depth: int = 0,
max_recursion_depth: Optional[int] = None,
dag_bag: Optional["DagBag"] = None,
- exclude_task_ids: Union[FrozenSet[str], FrozenSet[Tuple[str, int]], None] = frozenset(),
+ exclude_task_ids: FrozenSet[str] = frozenset(),
+ exclude_task_ids_and_map_indexes: FrozenSet[Tuple[str, int]] = frozenset({}),
) -> Union[int, Iterable[TaskInstance]]:
"""
Clears a set of task instances associated with the current dag for
a specified date range.
- :param task_ids: List of task ids or (``task_id``, ``map_index``) tuples to clear
+ :param task_ids_and_map_indexes: List of tuple of task_id, map_index to clear
:param start_date: The minimum execution_date to clear
:param end_date: The maximum execution_date to clear
:param only_failed: Only clear failed tasks
@@ -1802,7 +1814,8 @@ class DAG(LoggingMixin):
:param dry_run: Find the tasks to clear but don't clear them.
:param session: The sqlalchemy session to use
:param dag_bag: The DagBag used to find the dags subdags (Optional)
- :param exclude_task_ids: A set of ``task_id`` or (``task_id``, ``map_index``)
+ :param exclude_task_ids: A set of ``task_id`` that should not be cleared
+ :param exclude_task_ids_and_map_indexes: A set of ``task_id``,``map_index``
tuples that should not be cleared
"""
if get_tis:
@@ -1832,9 +1845,12 @@ class DAG(LoggingMixin):
if only_running:
# Yes, having `+=` doesn't make sense, but this was the existing behaviour
state += [State.RUNNING]
+ if task_ids:
+ task_ids_and_map_indexes = [(task_id, -1) for task_id in task_ids]
tis = self._get_task_instances(
task_ids=task_ids,
+ task_ids_and_map_indexes=task_ids_and_map_indexes,
start_date=start_date,
end_date=end_date,
run_id=None,
@@ -1845,6 +1861,7 @@ class DAG(LoggingMixin):
session=session,
dag_bag=dag_bag,
exclude_task_ids=exclude_task_ids,
+ exclude_task_ids_and_map_indexes=exclude_task_ids_and_map_indexes,
)
if dry_run:
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 5e7f7a01e6..7b0f81af80 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1962,7 +1962,7 @@ class Airflow(AirflowBaseView):
start_date,
end_date,
origin,
- map_indexes=None,
+ task_id_map_index_list=None,
recursive=False,
confirmed=False,
only_failed=False,
@@ -1971,7 +1971,7 @@ class Airflow(AirflowBaseView):
count = dag.clear(
start_date=start_date,
end_date=end_date,
- map_indexes=map_indexes,
+ task_ids_and_map_indexes=task_id_map_index_list,
include_subdags=recursive,
include_parentdag=recursive,
only_failed=only_failed,
@@ -1984,7 +1984,7 @@ class Airflow(AirflowBaseView):
tis = dag.clear(
start_date=start_date,
end_date=end_date,
- map_indexes=map_indexes,
+ task_ids_and_map_indexes=task_id_map_index_list,
include_subdags=recursive,
include_parentdag=recursive,
only_failed=only_failed,
@@ -2026,9 +2026,14 @@ class Airflow(AirflowBaseView):
task_id = request.form.get('task_id')
origin = get_safe_url(request.form.get('origin'))
dag = current_app.dag_bag.get_dag(dag_id)
+
map_indexes = request.form.get('map_indexes')
- if map_indexes and not isinstance(map_indexes, list):
- map_indexes = list(map_indexes)
+ if map_indexes:
+ if not isinstance(map_indexes, list):
+ map_indexes = list(map_indexes)
+ else:
+ map_indexes = [-1]
+ task_id_map_indexes = [(task_id, map_index) for map_index in map_indexes]
execution_date = request.form.get('execution_date')
execution_date = timezone.parse(execution_date)
@@ -2053,7 +2058,7 @@ class Airflow(AirflowBaseView):
start_date,
end_date,
origin,
- map_indexes=map_indexes,
+ task_id_map_index_list=task_id_map_indexes,
recursive=recursive,
confirmed=confirmed,
only_failed=only_failed,
@@ -2072,9 +2077,6 @@ class Airflow(AirflowBaseView):
dag_id = request.form.get('dag_id')
dag_run_id = request.form.get('dag_run_id')
confirmed = request.form.get('confirmed') == "true"
- map_indexes = request.form.get('map_indexes')
- if map_indexes and not isinstance(map_indexes, list):
- map_indexes = list(map_indexes)
dag = current_app.dag_bag.get_dag(dag_id)
dr = dag.get_dagrun(run_id=dag_run_id)
@@ -2085,7 +2087,6 @@ class Airflow(AirflowBaseView):
dag,
start_date,
end_date,
- map_indexes=map_indexes,
origin=None,
recursive=True,
confirmed=confirmed,
@@ -2339,8 +2340,11 @@ class Airflow(AirflowBaseView):
state = args.get('state')
origin = args.get('origin')
map_indexes = args.get('map_indexes')
- if map_indexes and not isinstance(map_indexes, list):
- map_indexes = list(map_indexes)
+ if map_indexes:
+ if not isinstance(map_indexes, list):
+ map_indexes = list(map_indexes)
+ else:
+ map_indexes = [-1]
upstream = to_boolean(args.get('upstream'))
downstream = to_boolean(args.get('downstream'))
@@ -2376,7 +2380,7 @@ class Airflow(AirflowBaseView):
from airflow.api.common.mark_tasks import set_state
to_be_altered = set_state(
- tasks=[task],
+ tasks=[(task, map_index) for map_index in map_indexes],
map_indexes=map_indexes,
run_id=dag_run_id,
upstream=upstream,
@@ -2418,8 +2422,11 @@ class Airflow(AirflowBaseView):
origin = get_safe_url(args.get('origin'))
dag_run_id = args.get('dag_run_id')
map_indexes = args.get('map_indexes')
- if map_indexes and not isinstance(map_indexes, list):
- map_indexes = list(map_indexes)
+ if map_indexes:
+ if not isinstance(map_indexes, list):
+ map_indexes = list(map_indexes)
+ else:
+ map_indexes = [-1]
upstream = to_boolean(args.get('upstream'))
downstream = to_boolean(args.get('downstream'))
@@ -2455,8 +2462,11 @@ class Airflow(AirflowBaseView):
origin = get_safe_url(args.get('origin'))
dag_run_id = args.get('dag_run_id')
map_indexes = args.get('map_indexes')
- if map_indexes and not isinstance(map_indexes, list):
- map_indexes = list(map_indexes)
+ if map_indexes:
+ if not isinstance(map_indexes, list):
+ map_indexes = list(map_indexes)
+ else:
+ map_indexes = [-1]
upstream = to_boolean(args.get('upstream'))
downstream = to_boolean(args.get('downstream'))
diff --git a/tests/api/common/test_mark_tasks.py b/tests/api/common/test_mark_tasks.py
index dedb624b1f..4c1d3c604b 100644
--- a/tests/api/common/test_mark_tasks.py
+++ b/tests/api/common/test_mark_tasks.py
@@ -439,12 +439,12 @@ class TestMarkTasks:
def test_mark_mapped_task_instance_state(self):
# set mapped task instance to success
snapshot = TestMarkTasks.snapshot_state(self.dag4, self.execution_dates)
- tasks = [self.dag4.get_task("consumer_literal")]
+ task = self.dag4.get_task("consumer_literal")
+ tasks = [(task, 0), (task, 1)]
map_indexes = [0, 1]
dr = DagRun.find(dag_id=self.dag4.dag_id, execution_date=self.execution_dates[0])[0]
altered = set_state(
tasks=tasks,
- map_indexes=map_indexes,
run_id=dr.run_id,
upstream=False,
downstream=False,
@@ -456,7 +456,7 @@ class TestMarkTasks:
assert len(altered) == 2
self.verify_state(
self.dag4,
- [task.task_id for task in tasks],
+ [task.task_id for task, _ in tasks],
[self.execution_dates[0]],
State.SUCCESS,
snapshot,
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 41219ef87b..ed2119a490 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -1456,7 +1456,7 @@ class TestDag(unittest.TestCase):
session.flush()
dag.clear(
- map_indexes=[0],
+ task_ids_and_map_indexes=[(task_id, 0)],
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=1),
dag_run_state=dag_run_state,