You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by bb...@apache.org on 2022/04/20 19:00:47 UTC

[airflow] 05/19: fixup! fixup! fixup! fixup! Allow marking/clearing mapped taskinstances from the UI

This is an automated email from the ASF dual-hosted git repository.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit c5cc48f9f9161b187368d99b551c652e1da03de5
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Wed Apr 13 20:16:56 2022 +0100

    fixup! fixup! fixup! fixup! Allow marking/clearing mapped taskinstances from the UI
---
 airflow/api/common/mark_tasks.py    |  38 ++++-----
 airflow/models/dag.py               | 159 ++++++++++++++++++++----------------
 airflow/www/views.py                |  44 ++++++----
 tests/api/common/test_mark_tasks.py |   6 +-
 tests/models/test_dag.py            |   2 +-
 5 files changed, 133 insertions(+), 116 deletions(-)

diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index 1d4709fb82..84fd48f4e4 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -18,9 +18,9 @@
 """Marks tasks APIs."""
 
 from datetime import datetime
-from typing import TYPE_CHECKING, Collection, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union
 
-from sqlalchemy import or_
+from sqlalchemy import or_, tuple_
 from sqlalchemy.orm import contains_eager
 from sqlalchemy.orm.session import Session as SASession
 
@@ -32,7 +32,6 @@ from airflow.operators.subdag import SubDagOperator
 from airflow.utils import timezone
 from airflow.utils.helpers import exactly_one
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import tuple_in_condition
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunType
 
@@ -79,7 +78,7 @@ def _create_dagruns(
 @provide_session
 def set_state(
     *,
-    tasks: Union[Collection[Operator], Collection[Tuple[Operator, int]]],
+    tasks: Union[Iterable[Operator], Iterable[Tuple[Operator, int]]],
     run_id: Optional[str] = None,
     execution_date: Optional[datetime] = None,
     upstream: bool = False,
@@ -97,7 +96,7 @@ def set_state(
     tasks that did not exist. It will not create dag runs that are missing
     on the schedule (but it will as for subdag dag runs if needed).
 
-    :param tasks: the iterable of tasks or (task, map_index) tuples from which to work.
+    :param tasks: the iterable of tasks or (task, map_index) tuples from which to work.
         task.task.dag needs to be set
     :param run_id: the run_id of the dagrun to start looking from
     :param execution_date: the execution date from which to start looking(deprecated)
@@ -120,7 +119,9 @@ def set_state(
     if execution_date and not timezone.is_localized(execution_date):
         raise ValueError(f"Received non-localized date {execution_date}")
 
-    task_dags = {task[0].dag if isinstance(task, tuple) else task.dag for task in tasks}
+    t_dags = {task.dag for task in tasks if not isinstance(task, tuple)}
+    t_dags_2 = {item[0].dag for item in tasks if isinstance(item, tuple)}
+    task_dags = t_dags | t_dags_2
     if len(task_dags) > 1:
         raise ValueError(f"Received tasks from multiple DAGs: {task_dags}")
     dag = next(iter(task_dags))
@@ -135,12 +136,6 @@ def set_state(
     dag_run_ids = get_run_ids(dag, run_id, future, past)
     task_id_map_index_list = list(find_task_relatives(tasks, downstream, upstream))
     task_ids = [task_id for task_id, _ in task_id_map_index_list]
-    # check if task_id_map_index_list contains map_index of None
-    # if it contains None, there was no map_index supplied for the task
-    for _, index in task_id_map_index_list:
-        if index is None:
-            task_id_map_index_list = [task_id for task_id, _ in task_id_map_index_list]
-            break
 
     confirmed_infos = list(_iter_existing_dag_run_infos(dag, dag_run_ids))
     confirmed_dates = [info.logical_date for info in confirmed_infos]
@@ -187,26 +182,20 @@ def get_all_dag_task_query(
     dag: DAG,
     session: SASession,
     state: TaskInstanceState,
-    task_ids: Union[List[str], List[Tuple[str, int]]],
+    task_id_map_index_list: List[Tuple[str, int]],
     confirmed_dates: Iterable[datetime],
 ):
     """Get all tasks of the main dag that will be affected by a state change"""
-    is_string_list = isinstance(task_ids[0], str)
     qry_dag = (
         session.query(TaskInstance)
         .join(TaskInstance.dag_run)
         .filter(
             TaskInstance.dag_id == dag.dag_id,
             DagRun.execution_date.in_(confirmed_dates),
+            tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_id_map_index_list),
         )
-    )
-
-    if is_string_list:
-        qry_dag = qry_dag.filter(TaskInstance.task_id.in_(task_ids))
-    else:
-        qry_dag = qry_dag.filter(tuple_in_condition((TaskInstance.task_id, TaskInstance.map_index), task_ids))
-    qry_dag = qry_dag.filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state)).options(
-        contains_eager(TaskInstance.dag_run)
+        .filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
+        .options(contains_eager(TaskInstance.dag_run))
     )
     return qry_dag
 
@@ -282,13 +271,14 @@ def _iter_existing_dag_run_infos(dag: DAG, run_ids: List[str]) -> Iterator[_DagR
         yield _DagRunInfo(dag_run.logical_date, dag.get_run_data_interval(dag_run))
 
 
-def find_task_relatives(tasks, downstream, upstream):
+@provide_session
+def find_task_relatives(tasks, downstream, upstream, session: SASession = NEW_SESSION):
     """Yield task ids and optionally ancestor and descendant ids."""
     for item in tasks:
         if isinstance(item, tuple):
             task, map_index = item
         else:
-            task, map_index = item, None
+            task, map_index = item, -1
         yield task.task_id, map_index
         if downstream:
             for relative in task.get_flat_relatives(upstream=False):
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index e9c33acb72..931fd469d7 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -39,7 +39,6 @@ from typing import (
     Iterable,
     List,
     Optional,
-    Sequence,
     Set,
     Tuple,
     Type,
@@ -52,7 +51,7 @@ import jinja2
 import pendulum
 from dateutil.relativedelta import relativedelta
 from pendulum.tz.timezone import Timezone
-from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, not_, or_
+from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, or_, tuple_
 from sqlalchemy.orm import backref, joinedload, relationship
 from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.session import Session
@@ -85,7 +84,7 @@ from airflow.utils.file import correct_maybe_zipped
 from airflow.utils.helpers import exactly_one, validate_key
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, tuple_in_condition, with_row_locks
+from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, with_row_locks
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType
 
@@ -1341,33 +1340,47 @@ class DAG(LoggingMixin):
             start_date = (timezone.utcnow() - timedelta(30)).replace(
                 hour=0, minute=0, second=0, microsecond=0
             )
-        query = self._get_task_instances(
-            task_ids=None,
-            start_date=start_date,
-            end_date=end_date,
-            run_id=None,
-            state=state or (),
-            include_subdags=False,
-            include_parentdag=False,
-            include_dependent_dags=False,
-            exclude_task_ids=(),
-            session=session,
+
+        if state is None:
+            state = []
+
+        return (
+            cast(
+                Query,
+                self._get_task_instances(
+                    task_ids=None,
+                    task_ids_and_map_indexes=None,
+                    start_date=start_date,
+                    end_date=end_date,
+                    run_id=None,
+                    state=state,
+                    include_subdags=False,
+                    include_parentdag=False,
+                    include_dependent_dags=False,
+                    exclude_task_ids=cast(List[str], []),
+                    exclude_task_ids_and_map_indexes=None,
+                    session=session,
+                ),
+            )
+            .order_by(DagRun.execution_date)
+            .all()
         )
-        return cast(Query, query).order_by(DagRun.execution_date).all()
 
     @overload
     def _get_task_instances(
         self,
         *,
-        task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
+        task_ids: Iterable[str],
+        task_ids_and_map_indexes: Optional[Iterable[Tuple[str, int]]],
         start_date: Optional[datetime],
         end_date: Optional[datetime],
         run_id: Optional[str],
-        state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
+        state: Union[TaskInstanceState, List[TaskInstanceState]],
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
-        exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
+        exclude_task_ids: Collection[str],
+        exclude_task_ids_and_map_indexes: Collection[Tuple[str, int]],
         session: Session,
         dag_bag: Optional["DagBag"] = ...,
     ) -> Iterable[TaskInstance]:
@@ -1377,16 +1390,18 @@ class DAG(LoggingMixin):
     def _get_task_instances(
         self,
         *,
-        task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
+        task_ids: Iterable[str],
+        task_ids_and_map_indexes: Optional[Iterable[Tuple[str, int]]],
         as_pk_tuple: Literal[True],
         start_date: Optional[datetime],
         end_date: Optional[datetime],
         run_id: Optional[str],
-        state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
+        state: Union[TaskInstanceState, List[TaskInstanceState]],
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
-        exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
+        exclude_task_ids: Collection[str],
+        exclude_task_ids_and_map_indexes: Collection[Tuple[str, int]],
         session: Session,
         dag_bag: Optional["DagBag"] = ...,
         recursion_depth: int = ...,
@@ -1398,16 +1413,18 @@ class DAG(LoggingMixin):
     def _get_task_instances(
         self,
         *,
-        task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
+        task_ids: Iterable[str],
+        task_ids_and_map_indexes: Optional[Iterable[Tuple[str, int]]],
         as_pk_tuple: Literal[True, None] = None,
         start_date: Optional[datetime],
         end_date: Optional[datetime],
         run_id: Optional[str],
-        state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
+        state: Union[TaskInstanceState, List[TaskInstanceState]],
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
-        exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
+        exclude_task_ids: Collection[str],
+        exclude_task_ids_and_map_indexes: Collection[Tuple[str, int]],
         session: Session,
         dag_bag: Optional["DagBag"] = None,
         recursion_depth: int = 0,
@@ -1431,6 +1448,11 @@ class DAG(LoggingMixin):
             tis = session.query(TaskInstance)
         tis = tis.join(TaskInstance.dag_run)
 
+        if task_ids is not None:  # task not mapped
+            task_ids_and_map_indexes = [(task_id, -1) for task_id in task_ids]
+        if exclude_task_ids and len(exclude_task_ids) > 0:  # task not mapped
+            exclude_task_ids_and_map_indexes = [(task_id, -1) for task_id in task_ids]
+
         if include_subdags:
             # Crafting the right filter for dag_id and task_ids combo
             conditions = []
@@ -1445,13 +1467,10 @@ class DAG(LoggingMixin):
             tis = tis.filter(TaskInstance.run_id == run_id)
         if start_date:
             tis = tis.filter(DagRun.execution_date >= start_date)
-
-        if task_ids is None:
-            pass  # Disable filter if not set.
-        elif isinstance(next(iter(task_ids), None), str):
-            tis = tis.filter(TI.task_id.in_(task_ids))
-        else:
-            tis = tis.filter(tuple_in_condition((TI.task_id, TI.map_index), task_ids))
+        if task_ids_and_map_indexes:
+            tis = tis.filter(
+                tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_ids_and_map_indexes)
+            )
 
         # This allows allow_trigger_in_future config to take affect, rather than mandating exec_date <= UTC
         if end_date or not self.allow_future_exec_dates:
@@ -1490,6 +1509,7 @@ class DAG(LoggingMixin):
             result.update(
                 p_dag._get_task_instances(
                     task_ids=task_ids,
+                    task_ids_and_map_indexes=task_ids_and_map_indexes,
                     start_date=start_date,
                     end_date=end_date,
                     run_id=None,
@@ -1499,6 +1519,7 @@ class DAG(LoggingMixin):
                     include_dependent_dags=include_dependent_dags,
                     as_pk_tuple=True,
                     exclude_task_ids=exclude_task_ids,
+                    exclude_task_ids_and_map_indexes=exclude_task_ids_and_map_indexes,
                     session=session,
                     dag_bag=dag_bag,
                     recursion_depth=recursion_depth,
@@ -1566,7 +1587,7 @@ class DAG(LoggingMixin):
                     )
                     result.update(
                         downstream._get_task_instances(
-                            task_ids=None,
+                            task_ids_and_map_indexes=None,
                             run_id=tii.run_id,
                             start_date=None,
                             end_date=None,
@@ -1575,7 +1596,7 @@ class DAG(LoggingMixin):
                             include_dependent_dags=include_dependent_dags,
                             include_parentdag=False,
                             as_pk_tuple=True,
-                            exclude_task_ids=exclude_task_ids,
+                            exclude_task_ids_and_map_indexes=exclude_task_ids_and_map_indexes,
                             dag_bag=dag_bag,
                             session=session,
                             recursion_depth=recursion_depth + 1,
@@ -1589,29 +1610,25 @@ class DAG(LoggingMixin):
             if as_pk_tuple:
                 result.update(TaskInstanceKey(*cols) for cols in tis.all())
             else:
-                result.update(ti.key for ti in tis)
-
-            if exclude_task_ids is not None:
-                result = {
-                    task
-                    for task in result
-                    if task.task_id not in exclude_task_ids
-                    and (task.task_id, task.map_index) not in exclude_task_ids
-                }
+                result.update(ti.key for ti in tis.all())
+
+            if exclude_task_ids_and_map_indexes:
+                result = set(
+                    filter(
+                        lambda key: (key.task_id, key.map_index) not in exclude_task_ids_and_map_indexes,
+                        result,
+                    )
+                )
 
         if as_pk_tuple:
             return result
-        if result:
+        elif result:
             # We've been asked for objects, lets combine it all back in to a result set
-            ti_filters = TI.filter_for_tis(result)
-            if ti_filters is not None:
-                tis = session.query(TI).filter(ti_filters)
-        elif exclude_task_ids is None:
-            pass  # Disable filter if not set.
-        elif isinstance(next(iter(exclude_task_ids), None), str):
-            tis = tis.filter(TI.task_id.notin_(exclude_task_ids))
-        else:
-            tis = tis.filter(not_(tuple_in_condition((TI.task_id, TI.map_index), exclude_task_ids)))
+            tis = tis.with_entities(TI.dag_id, TI.task_id, TI.run_id, TI.map_index)
+
+            tis = session.query(TI).filter(TI.filter_for_tis(result))
+        elif exclude_task_ids_and_map_indexes:
+            tis = tis.filter(tuple_(TI.task_id, TI.map_index).notin_(exclude_task_ids_and_map_indexes))
 
         return tis
 
@@ -1620,7 +1637,7 @@ class DAG(LoggingMixin):
         self,
         *,
         task_id: str,
-        map_indexes: Optional[Collection[int]] = None,
+        map_indexes: Optional[Iterable[int]] = None,
         execution_date: Optional[datetime] = None,
         run_id: Optional[str] = None,
         state: TaskInstanceState,
@@ -1636,8 +1653,7 @@ class DAG(LoggingMixin):
         in failed or upstream_failed state.
 
         :param task_id: Task ID of the TaskInstance
-        :param map_indexes: Only set TaskInstance if its map_index matches.
-            If None (default), all mapped TaskInstances of the task are set.
+        :param map_indexes: Map indexes of the task instances to set; ``[-1]`` is used if not given
         :param execution_date: Execution date of the TaskInstance
         :param run_id: The run_id of the TaskInstance
         :param state: State to set the TaskInstance to
@@ -1662,18 +1678,13 @@ class DAG(LoggingMixin):
 
         task = self.get_task(task_id)
         task.dag = self
-
-        tasks_to_set_state: Union[List[Operator], List[Tuple[Operator, int]]]
-        task_ids_to_exclude_from_clear: Union[Set[str], Set[Tuple[str, int]]]
-        if map_indexes is None:
-            tasks_to_set_state = [task]
-            task_ids_to_exclude_from_clear = {task_id}
-        else:
-            tasks_to_set_state = [(task, map_index) for map_index in map_indexes]
-            task_ids_to_exclude_from_clear = {(task_id, map_index) for map_index in map_indexes}
+        if not map_indexes:
+            map_indexes = [-1]
+        task_map_indexes = [(task, map_index) for map_index in map_indexes]
+        task_id_map_indexes = [(task_id, map_index) for map_index in map_indexes]
 
         altered = set_state(
-            tasks=tasks_to_set_state,
+            tasks=task_map_indexes,
             execution_date=execution_date,
             run_id=run_id,
             upstream=upstream,
@@ -1703,13 +1714,12 @@ class DAG(LoggingMixin):
         subdag.clear(
             start_date=start_date,
             end_date=end_date,
-            map_indexes=map_indexes,
             include_subdags=True,
             include_parentdag=True,
             only_failed=True,
             session=session,
             # Exclude the task itself from being cleared
-            exclude_task_ids=task_ids_to_exclude_from_clear,
+            exclude_task_ids_and_map_indexes=task_id_map_indexes,
         )
 
         return altered
@@ -1767,7 +1777,8 @@ class DAG(LoggingMixin):
     @provide_session
     def clear(
         self,
-        task_ids: Union[Collection[str], Collection[Tuple[str, int]], None] = None,
+        task_ids=None,
+        task_ids_and_map_indexes: Optional[Iterable[Tuple[str, int]]] = None,
         start_date: Optional[datetime] = None,
         end_date: Optional[datetime] = None,
         only_failed: bool = False,
@@ -1782,13 +1793,14 @@ class DAG(LoggingMixin):
         recursion_depth: int = 0,
         max_recursion_depth: Optional[int] = None,
         dag_bag: Optional["DagBag"] = None,
-        exclude_task_ids: Union[FrozenSet[str], FrozenSet[Tuple[str, int]], None] = frozenset(),
+        exclude_task_ids: FrozenSet[str] = frozenset(),
+        exclude_task_ids_and_map_indexes: FrozenSet[Tuple[str, int]] = frozenset({}),
     ) -> Union[int, Iterable[TaskInstance]]:
         """
         Clears a set of task instances associated with the current dag for
         a specified date range.
 
-        :param task_ids: List of task ids or (``task_id``, ``map_index``) tuples to clear
+        :param task_ids_and_map_indexes: List of (``task_id``, ``map_index``) tuples to clear
         :param start_date: The minimum execution_date to clear
         :param end_date: The maximum execution_date to clear
         :param only_failed: Only clear failed tasks
@@ -1802,7 +1814,8 @@ class DAG(LoggingMixin):
         :param dry_run: Find the tasks to clear but don't clear them.
         :param session: The sqlalchemy session to use
         :param dag_bag: The DagBag used to find the dags subdags (Optional)
-        :param exclude_task_ids: A set of ``task_id`` or (``task_id``, ``map_index``)
+        :param exclude_task_ids: A set of ``task_id`` that should not be cleared
+        :param exclude_task_ids_and_map_indexes: A set of (``task_id``, ``map_index``)
             tuples that should not be cleared
         """
         if get_tis:
@@ -1832,9 +1845,12 @@ class DAG(LoggingMixin):
         if only_running:
             # Yes, having `+=` doesn't make sense, but this was the existing behaviour
             state += [State.RUNNING]
+        if task_ids:
+            task_ids_and_map_indexes = [(task_id, -1) for task_id in task_ids]
 
         tis = self._get_task_instances(
             task_ids=task_ids,
+            task_ids_and_map_indexes=task_ids_and_map_indexes,
             start_date=start_date,
             end_date=end_date,
             run_id=None,
@@ -1845,6 +1861,7 @@ class DAG(LoggingMixin):
             session=session,
             dag_bag=dag_bag,
             exclude_task_ids=exclude_task_ids,
+            exclude_task_ids_and_map_indexes=exclude_task_ids_and_map_indexes,
         )
 
         if dry_run:
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 5e7f7a01e6..7b0f81af80 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1962,7 +1962,7 @@ class Airflow(AirflowBaseView):
         start_date,
         end_date,
         origin,
-        map_indexes=None,
+        task_id_map_index_list=None,
         recursive=False,
         confirmed=False,
         only_failed=False,
@@ -1971,7 +1971,7 @@ class Airflow(AirflowBaseView):
             count = dag.clear(
                 start_date=start_date,
                 end_date=end_date,
-                map_indexes=map_indexes,
+                task_ids_and_map_indexes=task_id_map_index_list,
                 include_subdags=recursive,
                 include_parentdag=recursive,
                 only_failed=only_failed,
@@ -1984,7 +1984,7 @@ class Airflow(AirflowBaseView):
             tis = dag.clear(
                 start_date=start_date,
                 end_date=end_date,
-                map_indexes=map_indexes,
+                task_ids_and_map_indexes=task_id_map_index_list,
                 include_subdags=recursive,
                 include_parentdag=recursive,
                 only_failed=only_failed,
@@ -2026,9 +2026,14 @@ class Airflow(AirflowBaseView):
         task_id = request.form.get('task_id')
         origin = get_safe_url(request.form.get('origin'))
         dag = current_app.dag_bag.get_dag(dag_id)
+
         map_indexes = request.form.get('map_indexes')
-        if map_indexes and not isinstance(map_indexes, list):
-            map_indexes = list(map_indexes)
+        if map_indexes:
+            if not isinstance(map_indexes, list):
+                map_indexes = list(map_indexes)
+        else:
+            map_indexes = [-1]
+        task_id_map_indexes = [(task_id, map_index) for map_index in map_indexes]
 
         execution_date = request.form.get('execution_date')
         execution_date = timezone.parse(execution_date)
@@ -2053,7 +2058,7 @@ class Airflow(AirflowBaseView):
             start_date,
             end_date,
             origin,
-            map_indexes=map_indexes,
+            task_id_map_index_list=task_id_map_indexes,
             recursive=recursive,
             confirmed=confirmed,
             only_failed=only_failed,
@@ -2072,9 +2077,6 @@ class Airflow(AirflowBaseView):
         dag_id = request.form.get('dag_id')
         dag_run_id = request.form.get('dag_run_id')
         confirmed = request.form.get('confirmed') == "true"
-        map_indexes = request.form.get('map_indexes')
-        if map_indexes and not isinstance(map_indexes, list):
-            map_indexes = list(map_indexes)
 
         dag = current_app.dag_bag.get_dag(dag_id)
         dr = dag.get_dagrun(run_id=dag_run_id)
@@ -2085,7 +2087,6 @@ class Airflow(AirflowBaseView):
             dag,
             start_date,
             end_date,
-            map_indexes=map_indexes,
             origin=None,
             recursive=True,
             confirmed=confirmed,
@@ -2339,8 +2340,11 @@ class Airflow(AirflowBaseView):
         state = args.get('state')
         origin = args.get('origin')
         map_indexes = args.get('map_indexes')
-        if map_indexes and not isinstance(map_indexes, list):
-            map_indexes = list(map_indexes)
+        if map_indexes:
+            if not isinstance(map_indexes, list):
+                map_indexes = list(map_indexes)
+        else:
+            map_indexes = [-1]
 
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
@@ -2376,7 +2380,7 @@ class Airflow(AirflowBaseView):
         from airflow.api.common.mark_tasks import set_state
 
         to_be_altered = set_state(
-            tasks=[task],
+            tasks=[(task, map_index) for map_index in map_indexes],
             map_indexes=map_indexes,
             run_id=dag_run_id,
             upstream=upstream,
@@ -2418,8 +2422,11 @@ class Airflow(AirflowBaseView):
         origin = get_safe_url(args.get('origin'))
         dag_run_id = args.get('dag_run_id')
         map_indexes = args.get('map_indexes')
-        if map_indexes and not isinstance(map_indexes, list):
-            map_indexes = list(map_indexes)
+        if map_indexes:
+            if not isinstance(map_indexes, list):
+                map_indexes = list(map_indexes)
+        else:
+            map_indexes = [-1]
 
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
@@ -2455,8 +2462,11 @@ class Airflow(AirflowBaseView):
         origin = get_safe_url(args.get('origin'))
         dag_run_id = args.get('dag_run_id')
         map_indexes = args.get('map_indexes')
-        if map_indexes and not isinstance(map_indexes, list):
-            map_indexes = list(map_indexes)
+        if map_indexes:
+            if not isinstance(map_indexes, list):
+                map_indexes = list(map_indexes)
+        else:
+            map_indexes = [-1]
 
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
diff --git a/tests/api/common/test_mark_tasks.py b/tests/api/common/test_mark_tasks.py
index dedb624b1f..4c1d3c604b 100644
--- a/tests/api/common/test_mark_tasks.py
+++ b/tests/api/common/test_mark_tasks.py
@@ -439,12 +439,12 @@ class TestMarkTasks:
     def test_mark_mapped_task_instance_state(self):
         # set mapped task instance to success
         snapshot = TestMarkTasks.snapshot_state(self.dag4, self.execution_dates)
-        tasks = [self.dag4.get_task("consumer_literal")]
+        task = self.dag4.get_task("consumer_literal")
+        tasks = [(task, 0), (task, 1)]
         map_indexes = [0, 1]
         dr = DagRun.find(dag_id=self.dag4.dag_id, execution_date=self.execution_dates[0])[0]
         altered = set_state(
             tasks=tasks,
-            map_indexes=map_indexes,
             run_id=dr.run_id,
             upstream=False,
             downstream=False,
@@ -456,7 +456,7 @@ class TestMarkTasks:
         assert len(altered) == 2
         self.verify_state(
             self.dag4,
-            [task.task_id for task in tasks],
+            [task.task_id for task, _ in tasks],
             [self.execution_dates[0]],
             State.SUCCESS,
             snapshot,
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 41219ef87b..ed2119a490 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -1456,7 +1456,7 @@ class TestDag(unittest.TestCase):
         session.flush()
 
         dag.clear(
-            map_indexes=[0],
+            task_ids_and_map_indexes=[(task_id, 0)],
             start_date=DEFAULT_DATE,
             end_date=DEFAULT_DATE + datetime.timedelta(days=1),
             dag_run_state=dag_run_state,