Posted to commits@airflow.apache.org by as...@apache.org on 2022/02/04 14:25:13 UTC

[airflow] branch main updated: Make `airflow dags test` be able to execute Mapped Tasks (#21210)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6fc6edf  Make `airflow dags test` be able to execute Mapped Tasks (#21210)
6fc6edf is described below

commit 6fc6edf6af7f676bfa54ff3a2e6e6d2edb938f2e
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Fri Feb 4 14:24:32 2022 +0000

    Make `airflow dags test` be able to execute Mapped Tasks (#21210)
    
    * Make `airflow dags test` be able to execute Mapped Tasks
    
    In order to do this, two steps were required:
    
    - The BackfillJob needs to know about mapped tasks, both to expand them
      and to update its TI tracking
    - The DebugExecutor needed to "unmap" the mapped task to get the real
      operator back (sketched below)
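    
    The DebugExecutor-side "unmap" step boils down to a pattern like the
    following (a minimal sketch of the change below, not the exact code; it
    assumes a TaskInstance `ti` whose `task` may be a MappedOperator):
    
    ```
    # Swap the MappedOperator for the real, unmapped operator it wraps,
    # then run the raw task as usual.
    if ti.task.is_mapped:
        ti.task = ti.task.unmap()
    ti._run_raw_task(job_id=ti.job_id)
    ```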
    
    I was testing this with the following dag:
    
    ```
    from airflow import DAG
    from airflow.decorators import task
    from airflow.operators.python import PythonOperator
    import pendulum
    
    @task
    def make_list():
        return list(map(lambda a: f'echo "{a!r}"', [1, 2, {'a': 'b'}]))
    
    def consumer(*args):
        print(repr(args))
    
    with DAG(dag_id='maptest', start_date=pendulum.DateTime(2022, 1, 18)) as dag:
        PythonOperator(task_id='consumer', python_callable=consumer).map(op_args=make_list())
    ```
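    
    With this change, the DAG above can be exercised locally with e.g.
    `airflow dags test maptest 2022-01-18` (assuming the file is on the DAGs
    folder path), which runs it through a BackfillJob using the DebugExecutor.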
    
    It can't "unmap" decorated (TaskFlow) operators successfully yet, so we're
    using the old-school PythonOperator here
    
    We also currently pass the whole mapped value to the operator, not just
    the value(s) for the current map index
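    
    Concretely, with the DAG above every expanded `consumer` task instance is
    currently called with the whole list. A plain-Python sketch of the
    difference (no Airflow involved):
    
    ```
    # Stand-in for the XCom value returned by make_list() in the DAG above.
    values = ['echo "1"', 'echo "2"', 'echo "3"']
    
    def consumer(*args):
        print(repr(args))
    
    # What happens today: every expanded task instance sees the whole value.
    for _ in values:
        consumer(*values)
    
    # What mapping is ultimately meant to do: each task instance only sees
    # the value(s) for its own map index (exact convention still to come).
    for value in values:
        consumer(value)
    ```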
    
    * Always have a `task_group` property on DAGNodes
    
    And since TaskGroup is a DAGNode, we no longer need to store the parent
    group directly -- it is already available as the node's own `task_group`
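    
    A minimal sketch of the idea (names as in this commit, heavily simplified):
    
    ```
    class DAGNode:
        task_group = None  # the TaskGroup that contains this node
    
    class TaskGroup(DAGNode):
        @property
        def parent_group(self):
            # A TaskGroup is itself a DAGNode, so its parent is just its own
            # task_group -- no separate _parent_group attribute is needed.
            return self.task_group
    ```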
    
    * Add "integation" tests for running mapped tasks via BackfillJob
    
    * Only show "Map Index" in Backfill report when relevant
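    
    The report logic is roughly the following (a sketch of the idea; the real
    code lives in `_collect_errors` in backfill_job.py and differs in detail):
    
    ```
    from tabulate import tabulate
    
    def tabulate_ti_keys(sorted_ti_keys):
        # Only add a "Map Index" column when at least one key was actually
        # mapped (unmapped task instances keep map_index == -1).
        if all(key.map_index == -1 for key in sorted_ti_keys):
            headers = ["DAG ID", "Task ID", "Run ID", "Try number"]
            rows = [key[0:4] for key in sorted_ti_keys]
        else:
            headers = ["DAG ID", "Task ID", "Run ID", "Try number", "Map Index"]
            rows = [(*key[0:4], key.map_index) for key in sorted_ti_keys]
        return tabulate(rows, headers=headers)
    ```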
    
    Co-authored-by: Tzu-ping Chung <ur...@gmail.com>
---
 airflow/cli/commands/task_command.py               |   2 +
 airflow/executors/debug_executor.py                |   2 +
 airflow/executors/kubernetes_executor.py           |   2 +-
 airflow/jobs/backfill_job.py                       | 117 ++++++++++--------
 airflow/jobs/local_task_job.py                     |   6 +
 airflow/jobs/scheduler_job.py                      |   2 +-
 airflow/models/baseoperator.py                     | 134 ++++++++++++---------
 airflow/models/taskinstance.py                     |  51 +++++---
 airflow/models/taskmixin.py                        |  52 +++++++-
 airflow/serialization/serialized_objects.py        |  22 ++--
 .../ti_deps/deps/mapped_task_expanded.py           |  16 ++-
 airflow/utils/task_group.py                        |  33 ++---
 .../__init__.py => dags/test_mapped_classic.py}    |  20 ++-
 tests/executors/test_kubernetes_executor.py        |   5 +-
 tests/jobs/test_backfill_job.py                    |  42 +++++--
 tests/models/__init__.py                           |   4 +-
 tests/models/test_baseoperator.py                  |  20 ++-
 tests/models/test_dag.py                           |   2 +-
 tests/models/test_taskinstance.py                  |   2 +-
 tests/serialization/test_dag_serialization.py      |   5 +
 tests/test_utils/mock_executor.py                  |   4 +-
 21 files changed, 366 insertions(+), 177 deletions(-)

diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index 537fab0..1b5208f 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -224,6 +224,8 @@ RAW_TASK_UNSUPPORTED_OPTION = [
 
 def _run_raw_task(args, ti: TaskInstance) -> None:
     """Runs the main task handling code"""
+    if ti.task.is_mapped:
+        ti.task = ti.task.unmap()
     ti._run_raw_task(
         mark_success=args.mark_success,
         job_id=args.job_id,
diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py
index 865186d..0ab5f35 100644
--- a/airflow/executors/debug_executor.py
+++ b/airflow/executors/debug_executor.py
@@ -76,6 +76,8 @@ class DebugExecutor(BaseExecutor):
         key = ti.key
         try:
             params = self.tasks_params.pop(ti.key, {})
+            if ti.task.is_mapped:
+                ti.task = ti.task.unmap()
             ti._run_raw_task(job_id=ti.job_id, **params)
             self.change_state(key, State.SUCCESS)
             ti._run_finished_callback()
diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py
index 1071a3a..ef671eb 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -296,7 +296,7 @@ class AirflowKubernetesScheduler(LoggingMixin):
         """
         self.log.info('Kubernetes job is %s', str(next_job).replace("\n", " "))
         key, command, kube_executor_config, pod_template_file = next_job
-        dag_id, task_id, run_id, try_number = key
+        dag_id, task_id, run_id, try_number, _ = key
 
         if command[0:3] != ["airflow", "tasks", "run"]:
             raise ValueError('The command must start with ["airflow", "tasks", "run"].')
diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py
index 406c2ea..10a5d08 100644
--- a/airflow/jobs/backfill_job.py
+++ b/airflow/jobs/backfill_job.py
@@ -18,9 +18,9 @@
 #
 
 import time
-from collections import OrderedDict
-from typing import Optional, Set
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Tuple
 
+import attr
 import pendulum
 from sqlalchemy.orm.session import Session, make_transient
 from tabulate import tabulate
@@ -48,6 +48,9 @@ from airflow.utils.session import provide_session
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunType
 
+if TYPE_CHECKING:
+    from airflow.models.baseoperator import MappedOperator
+
 
 class BackfillJob(BaseJob):
     """
@@ -60,6 +63,7 @@ class BackfillJob(BaseJob):
 
     __mapper_args__ = {'polymorphic_identity': 'BackfillJob'}
 
+    @attr.define
     class _DagRunTaskStatus:
         """
         Internal status of the backfill job. This class is intended to be instantiated
@@ -83,32 +87,17 @@ class BackfillJob(BaseJob):
         :param total_runs: Number of total dag runs able to run
         """
 
-        # TODO(edgarRd): AIRFLOW-1444: Add consistency check on counts
-        def __init__(
-            self,
-            to_run=None,
-            running=None,
-            skipped=None,
-            succeeded=None,
-            failed=None,
-            not_ready=None,
-            deadlocked=None,
-            active_runs=None,
-            executed_dag_run_dates=None,
-            finished_runs=0,
-            total_runs=0,
-        ):
-            self.to_run = to_run or OrderedDict()
-            self.running = running or {}
-            self.skipped = skipped or set()
-            self.succeeded = succeeded or set()
-            self.failed = failed or set()
-            self.not_ready = not_ready or set()
-            self.deadlocked = deadlocked or set()
-            self.active_runs = active_runs or []
-            self.executed_dag_run_dates = executed_dag_run_dates or set()
-            self.finished_runs = finished_runs
-            self.total_runs = total_runs
+        to_run: Dict[TaskInstanceKey, TaskInstance] = attr.ib(factory=dict)
+        running: Dict[TaskInstanceKey, TaskInstance] = attr.ib(factory=dict)
+        skipped: Set[TaskInstanceKey] = attr.ib(factory=set)
+        succeeded: Set[TaskInstanceKey] = attr.ib(factory=set)
+        failed: Set[TaskInstanceKey] = attr.ib(factory=set)
+        not_ready: Set[TaskInstanceKey] = attr.ib(factory=set)
+        deadlocked: Set[TaskInstance] = attr.ib(factory=set)
+        active_runs: List[DagRun] = attr.ib(factory=list)
+        executed_dag_run_dates: Set[pendulum.DateTime] = attr.ib(factory=set)
+        finished_runs: int = 0
+        total_runs: int = 0
 
     def __init__(
         self,
@@ -167,7 +156,6 @@ class BackfillJob(BaseJob):
         self.run_at_least_once = run_at_least_once
         super().__init__(*args, **kwargs)
 
-    @provide_session
     def _update_counters(self, ti_status, session=None):
         """
         Updates the counters per state of the tasks that were running. Can re-add
@@ -234,14 +222,22 @@ class BackfillJob(BaseJob):
             session.query(TI).filter(filter_for_tis).update(
                 values={TI.state: TaskInstanceState.SCHEDULED}, synchronize_session=False
             )
+            session.flush()
 
-    def _manage_executor_state(self, running):
+    def _manage_executor_state(
+        self, running, session
+    ) -> Iterator[Tuple["MappedOperator", str, Sequence[TaskInstance]]]:
         """
         Checks if the executor agrees with the state of task instances
-        that are running
+        that are running.
+
+        Expands downstream mapped tasks when necessary
 
         :param running: dict of key, task to verify
+        :return: An iterable of expanded TaskInstances per MappedOperator
         """
+        from airflow.models.baseoperator import MappedOperator
+
         executor = self.executor
 
         # TODO: query all instead of refresh from db
@@ -266,6 +262,11 @@ class BackfillJob(BaseJob):
                 )
                 self.log.error(msg)
                 ti.handle_failure_with_callback(error=msg)
+                continue
+            if ti.state not in self.STATES_COUNT_AS_RUNNING:
+                for node in ti.task.mapped_dependants():
+                    assert isinstance(node, MappedOperator)
+                    yield node, ti.run_id, node.expand_mapped_task(ti, session)
 
     @provide_session
     def _get_dag_run(self, dagrun_info: DagRunInfo, dag: DAG, session: Session = None):
@@ -409,7 +410,6 @@ class BackfillJob(BaseJob):
             # or leaf to root, as otherwise tasks might be
             # determined deadlocked while they are actually
             # waiting for their upstream to finish
-            @provide_session
             def _per_task_process(key, ti: TaskInstance, session=None):
                 ti.refresh_from_db(lock_for_update=True, session=session)
 
@@ -577,7 +577,8 @@ class BackfillJob(BaseJob):
                                     "Not scheduling since Task concurrency limit is reached."
                                 )
 
-                        _per_task_process(key, ti)
+                        _per_task_process(key, ti, session)
+                        session.commit()
             except (NoAvailablePoolSlot, DagConcurrencyLimitReached, TaskConcurrencyLimitReached) as e:
                 self.log.debug(e)
 
@@ -597,11 +598,23 @@ class BackfillJob(BaseJob):
                 ti_status.deadlocked.update(ti_status.to_run.values())
                 ti_status.to_run.clear()
 
-            # check executor state
-            self._manage_executor_state(ti_status.running)
+            # check executor state -- and expand any mapped TIs
+            for node, run_id, mapped_tis in self._manage_executor_state(ti_status.running, session):
+
+                def to_keep(key: TaskInstanceKey) -> bool:
+                    if key.dag_id != node.dag_id or key.task_id != node.task_id or key.run_id != run_id:
+                        # For another Dag/Task/Run -- don't remove
+                        return True
+                    return False
+
+                # remove the old unmapped TIs for node -- they have been replaced with the mapped TIs
+                ti_status.to_run = {key: ti for (key, ti) in ti_status.to_run.items() if to_keep(key)}
+
+                ti_status.to_run.update({ti.key: ti for ti in mapped_tis})
 
             # update the task counters
-            self._update_counters(ti_status=ti_status)
+            self._update_counters(ti_status=ti_status, session=session)
+            session.commit()
 
             # update dag run state
             _dag_runs = ti_status.active_runs[:]
@@ -613,25 +626,33 @@ class BackfillJob(BaseJob):
                     executed_run_dates.append(run.execution_date)
 
             self._log_progress(ti_status)
+            session.commit()
 
         # return updated status
         return executed_run_dates
 
     @provide_session
-    def _collect_errors(self, ti_status, session=None):
-        def tabulate_ti_keys_set(set_ti_keys: Set[TaskInstanceKey]) -> str:
+    def _collect_errors(self, ti_status: _DagRunTaskStatus, session=None):
+        def tabulate_ti_keys_set(ti_keys: Iterable[TaskInstanceKey]) -> str:
             # Sorting by execution date first
-            sorted_ti_keys = sorted(
-                set_ti_keys,
-                key=lambda ti_key: (ti_key.run_id, ti_key.dag_id, ti_key.task_id, ti_key.try_number),
+            sorted_ti_keys: Any = sorted(
+                ti_keys,
+                key=lambda ti_key: (
+                    ti_key.run_id,
+                    ti_key.dag_id,
+                    ti_key.task_id,
+                    ti_key.map_index,
+                    ti_key.try_number,
+                ),
             )
-            return tabulate(sorted_ti_keys, headers=["DAG ID", "Task ID", "Run ID", "Try number"])
 
-        def tabulate_tis_set(set_tis: Set[TaskInstance]) -> str:
-            # Sorting by execution date first
-            sorted_tis = sorted(set_tis, key=lambda ti: (ti.run_id, ti.dag_id, ti.task_id, ti.try_number))
-            tis_values = ((ti.dag_id, ti.task_id, ti.run_id, ti.try_number) for ti in sorted_tis)
-            return tabulate(tis_values, headers=["DAG ID", "Task ID", "Run ID", "Try number"])
+            if all(key.map_index == -1 for key in ti_keys):
+                headers = ["DAG ID", "Task ID", "Run ID", "Try number"]
+                sorted_ti_keys = map(lambda k: k[0:4], sorted_ti_keys)
+            else:
+                headers = ["DAG ID", "Task ID", "Run ID", "Map Index", "Try number"]
+
+            return tabulate(sorted_ti_keys, headers=headers)
 
         err = ''
         if ti_status.failed:
@@ -667,7 +688,7 @@ class BackfillJob(BaseJob):
             err += '\n\nThese tasks are skipped:\n'
             err += tabulate_ti_keys_set(ti_status.skipped)
             err += '\n\nThese tasks are deadlocked:\n'
-            err += tabulate_tis_set(ti_status.deadlocked)
+            err += tabulate_ti_keys_set([ti.key for ti in ti_status.deadlocked])
 
         return err
 
diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index c0255d7..05ee533 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -104,6 +104,12 @@ class LocalTaskJob(BaseJob):
         try:
             self.task_runner.start()
 
+            # Unmap the task _after_ it has forked/execed. (This is a bit of a kludge, but if we unmap before
+            # fork, then the "run_raw_task" command will see the mapping index and a non-mapped task and
+            # fail)
+            if self.task_instance.task.is_mapped:
+                self.task_instance.task = self.task_instance.task.unmap()
+
             heartbeat_time_limit = conf.getint('scheduler', 'scheduler_zombie_task_threshold')
 
             # task callback invocation happens either here or in
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index cbda16e..7a6e3ef 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -534,7 +534,7 @@ class SchedulerJob(BaseJob):
         """Respond to executor events."""
         if not self.processor_agent:
             raise ValueError("Processor agent is not started.")
-        ti_primary_key_to_try_number_map: Dict[Tuple[str, str, str], int] = {}
+        ti_primary_key_to_try_number_map: Dict[Tuple[str, str, str, int], int] = {}
         event_buffer = self.executor.get_event_buffer()
         tis_with_right_state: List[TaskInstanceKey] = []
 
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index d51dda8..34c8412 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -50,7 +50,7 @@ import attr
 import jinja2
 import pendulum
 from dateutil.relativedelta import relativedelta
-from sqlalchemy import or_
+from sqlalchemy import func, or_
 from sqlalchemy.orm import Session
 from sqlalchemy.orm.exc import NoResultFound
 
@@ -66,6 +66,7 @@ from airflow.models.taskmixin import DAGNode, DependencyMixin
 from airflow.models.xcom import XCOM_RETURN_KEY
 from airflow.serialization.enums import DagAttributeTypes
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
+from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
 from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
 from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep
 from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
@@ -247,7 +248,12 @@ class BaseOperatorMeta(abc.ABCMeta):
         # Validate that the args we passed are known -- at call/DAG parse time, not run time!
         _validate_kwarg_names_for_mapping(operator_class, "partial", kwargs)
         return MappedOperator(
-            task_id=task_id, operator_class=operator_class, dag=dag, partial_kwargs=kwargs, mapped_kwargs={}
+            task_id=task_id,
+            operator_class=operator_class,
+            dag=dag,
+            partial_kwargs=kwargs,
+            mapped_kwargs={},
+            deps=MappedOperator._deps(operator_class.deps),
         )
 
 
@@ -1459,9 +1465,7 @@ class BaseOperator(Operator, LoggingMixin, DAGNode, metaclass=BaseOperatorMeta):
         """Return if this operator can use smart service. Default False."""
         return False
 
-    @property
-    def is_mapped(self) -> bool:
-        return False
+    is_mapped: ClassVar[bool] = False
 
     @property
     def inherits_from_dummy_operator(self):
@@ -1491,38 +1495,10 @@ class BaseOperator(Operator, LoggingMixin, DAGNode, metaclass=BaseOperatorMeta):
     def map(self, **kwargs) -> "MappedOperator":
         return MappedOperator.from_operator(self, kwargs)
 
-    def has_mapped_dependants(self) -> bool:
-        """Whether any downstream dependencies depend on this task for mapping.
-
-        For now, this walks the entire DAG to find mapped nodes that has this
-        current task as an upstream. We cannot use ``downstream_list`` since it
-        only contains operators, not task groups. In the future, we should
-        provide a way to record an DAG node's all downstream nodes instead.
-        """
-        from airflow.utils.task_group import MappedTaskGroup, TaskGroup
-
-        if not self.has_dag():
-            return False
-
-        def _walk_group(group: TaskGroup) -> Iterable[Tuple[str, DAGNode]]:
-            """Recursively walk children in a task group.
-
-            This yields all direct children (including both tasks and task
-            groups), and all children of any task groups.
-            """
-            for key, child in group.children.items():
-                yield key, child
-                if isinstance(child, TaskGroup):
-                    yield from _walk_group(child)
-
-        for key, child in _walk_group(self.dag.task_group):
-            if key == self.task_id:
-                continue
-            if not isinstance(child, (MappedOperator, MappedTaskGroup)):
-                continue
-            if self.task_id in child.upstream_task_ids:
-                return True
-        return False
+    def unmap(self) -> "BaseOperator":
+        """:meta private:"""
+        # Exists to make typing easier
+        raise TypeError("Internal code error: Do not call unmap on BaseOperator!")
 
 
 def _validate_kwarg_names_for_mapping(
@@ -1591,7 +1567,7 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
     # Needed for SerializedBaseOperator
     _is_dummy: bool = attr.ib()
 
-    deps: Iterable[BaseTIDep] = attr.ib()
+    deps: Iterable[BaseTIDep]
     operator_extra_links: Iterable['BaseOperatorLink'] = ()
     template_fields: Collection[str] = attr.ib()
     template_ext: Collection[str] = attr.ib()
@@ -1602,16 +1578,16 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
 
     subdag: None = attr.ib(init=False)
 
+    DEFAULT_DEPS: ClassVar[FrozenSet[BaseTIDep]] = frozenset(BaseOperator.deps) | frozenset(
+        [MappedTaskIsExpanded()]
+    )
+
     @_is_dummy.default
     def _is_dummy_from_operator_class(self):
         from airflow.operators.dummy import DummyOperator
 
         return issubclass(self.operator_class, DummyOperator)
 
-    @deps.default
-    def _deps_from_operator_class(self):
-        return self.operator_class.deps
-
     @template_fields.default
     def _template_fields_from_operator_class(self):
         return self.operator_class.template_fields
@@ -1648,7 +1624,8 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
 
     @classmethod
     def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) -> "MappedOperator":
-        dag: Optional["DAG"] = getattr(operator, '_dag', None)
+        dag = operator.get_dag()
+        task_group = operator.task_group
         if dag:
             # When BaseOperator() was called within a DAG, it would have been added straight away, but now we
             # are mapped, we want to _remove_ that task from the dag
@@ -1658,7 +1635,7 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
         return MappedOperator(
             operator_class=type(operator),
             task_id=operator.task_id,
-            task_group=operator.task_group,
+            task_group=task_group,
             dag=dag,
             upstream_task_ids=operator.upstream_task_ids,
             downstream_task_ids=operator.downstream_task_ids,
@@ -1668,7 +1645,7 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
             mapped_kwargs=mapped_kwargs,
             owner=operator.owner,
             max_active_tis_per_dag=operator.max_active_tis_per_dag,
-            deps=operator.deps,
+            deps=cls._deps(operator.deps),
         )
 
     @classmethod
@@ -1695,12 +1672,20 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
             task_id=task_id,
             dag=dag,
             task_group=task_group,
+            deps=cls._deps(decorator.operator_class.deps),
         )
         operator.mapped_kwargs.update(mapped_kwargs)
         for arg in mapped_kwargs.values():
             XComArg.apply_upstream_relationship(operator, arg)
         return operator
 
+    @classmethod
+    def _deps(cls, deps: Iterable[BaseTIDep]):
+        if deps is BaseOperator.deps:
+            return cls.DEFAULT_DEPS
+        else:
+            return frozenset(deps) | {MappedTaskIsExpanded()}
+
     def __attrs_post_init__(self):
         from airflow.models.xcom_arg import XComArg
 
@@ -1756,9 +1741,7 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
         """Used to determine if an Operator is inherited from DummyOperator"""
         return self._is_dummy
 
-    @property
-    def is_mapped(self) -> bool:
-        return True
+    is_mapped: ClassVar[bool] = True
 
     # The _serialized_fields are lazily loaded when get_serialized_fields() method is called
     __serialized_fields: ClassVar[Optional[FrozenSet[str]]] = None
@@ -1777,6 +1760,7 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
                     'operator_extra_links',
                     'upstream_task_ids',
                     'task_type',
+                    'task_group',
                     # These are automatically populated from partial_kwargs. In
                     # a perfect world, they should be properties like other
                     # partial_kwargs-populated values e.g. 'queue' below, but we
@@ -1826,8 +1810,14 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
     def depends_on_past(self) -> bool:
         return self.partial_kwargs.get("depends_on_past") or self.wait_for_downstream
 
-    def expand_mapped_task(self, upstream_ti: "TaskInstance", session: "Session" = NEW_SESSION) -> None:
-        """Create the mapped TaskInstances for mapped task."""
+    def expand_mapped_task(
+        self, upstream_ti: "TaskInstance", session: "Session" = NEW_SESSION
+    ) -> Sequence[TaskInstance]:
+        """
+        Create the mapped TaskInstances for mapped task.
+
+        :return: The mapped TaskInstances
+        """
         # TODO: support having multiuple mapped upstreams?
         from airflow.models.taskmap import TaskMap
         from airflow.settings import task_instance_mutation_hook
@@ -1846,6 +1836,7 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
             # TODO: What would lead to this? How can this be better handled?
             raise RuntimeError("mapped operator cannot be expanded; upstream not found")
 
+        state = None
         unmapped_ti: Optional[TaskInstance] = (
             session.query(TaskInstance)
             .filter(
@@ -1858,6 +1849,8 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
             .one_or_none()
         )
 
+        ret: List[TaskInstance] = []
+
         if unmapped_ti:
             # The unmapped task instance still exists and is unfinished, i.e. we
             # haven't tried to run it before.
@@ -1867,20 +1860,34 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
                 self.log.info("Marking %s as SKIPPED since the map has 0 values to expand", unmapped_ti)
                 unmapped_ti.state = TaskInstanceState.SKIPPED
                 session.flush()
-                return
+                return ret
             # Otherwise convert this into the first mapped index, and create
             # TaskInstance for other indexes.
             unmapped_ti.map_index = 0
+            state = unmapped_ti.state
+            self.log.debug("Updated in place to become %s", unmapped_ti)
+            ret.append(unmapped_ti)
             indexes_to_map = range(1, task_map_info_length)
         else:
-            indexes_to_map = range(task_map_info_length)
+            # Only create "missing" ones.
+            current_max_mapping = (
+                session.query(func.max(TaskInstance.map_index))
+                .filter(
+                    TaskInstance.dag_id == upstream_ti.dag_id,
+                    TaskInstance.task_id == self.task_id,
+                    TaskInstance.run_id == upstream_ti.run_id,
+                )
+                .scalar()
+            )
+            indexes_to_map = range(current_max_mapping + 1, task_map_info_length)
 
         for index in indexes_to_map:
             # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
             # TODO: Change `TaskInstance` ctor to take Operator, not BaseOperator
-            ti = TaskInstance(self, run_id=upstream_ti.run_id, map_index=index)  # type: ignore
+            ti = TaskInstance(self, run_id=upstream_ti.run_id, map_index=index, state=state)  # type: ignore
+            self.log.debug("Expanding TIs upserted %s", ti)
             task_instance_mutation_hook(ti)
-            session.merge(ti)
+            ret.append(session.merge(ti))
 
         # Set to "REMOVED" any (old) TaskInstances with map indices greater
         # than the current map value
@@ -1893,6 +1900,25 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
 
         session.flush()
 
+        return ret
+
+    def unmap(self) -> BaseOperator:
+        """Get the "normal" Operator after applying the current mapping"""
+        assert not isinstance(self.operator_class, str)
+
+        dag = self.get_dag()
+        if not dag:
+            raise RuntimeError("Cannot unmap a task unless it has a dag")
+
+        args = {
+            **self.partial_kwargs,
+            **self.mapped_kwargs,
+        }
+        dag._remove_task(self.task_id)
+        task = self.operator_class(task_id=self.task_id, dag=self.dag, **args)
+
+        return task
+
 
 # TODO: Deprecate for Airflow 3.0
 Chainable = Union[DependencyMixin, Sequence[DependencyMixin]]
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 0528dbb..7f151f4d 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -303,20 +303,23 @@ class TaskInstanceKey(NamedTuple):
     task_id: str
     run_id: str
     try_number: int = 1
+    map_index: int = -1
 
     @property
-    def primary(self) -> Tuple[str, str, str]:
+    def primary(self) -> Tuple[str, str, str, int]:
         """Return task instance primary key part of the key"""
-        return self.dag_id, self.task_id, self.run_id
+        return self.dag_id, self.task_id, self.run_id, self.map_index
 
     @property
     def reduced(self) -> 'TaskInstanceKey':
         """Remake the key by subtracting 1 from try number to match in memory information"""
-        return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, max(1, self.try_number - 1))
+        return TaskInstanceKey(
+            self.dag_id, self.task_id, self.run_id, max(1, self.try_number - 1), self.map_index
+        )
 
     def with_try_number(self, try_number: int) -> 'TaskInstanceKey':
         """Returns TaskInstanceKey with provided ``try_number``"""
-        return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, try_number)
+        return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, try_number, self.map_index)
 
     @property
     def key(self) -> "TaskInstanceKey":
@@ -795,8 +798,6 @@ class TaskInstance(Base, LoggingMixin):
         else:
             self.state = None
 
-        self.log.debug("Refreshed TaskInstance %s", self)
-
     def refresh_from_task(self, task: "BaseOperator", pool_override=None):
         """
         Copy common attributes from the given task.
@@ -829,12 +830,11 @@ class TaskInstance(Base, LoggingMixin):
             execution_date=self.execution_date,
             session=session,
         )
-        self.log.debug("XCom data cleared")
 
     @property
     def key(self) -> TaskInstanceKey:
         """Returns a tuple that identifies the task instance uniquely"""
-        return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number)
+        return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index)
 
     @provide_session
     def set_state(self, state: Optional[str], session=NEW_SESSION):
@@ -1068,7 +1068,10 @@ class TaskInstance(Base, LoggingMixin):
                     yield dep_status
 
     def __repr__(self):
-        return f"<TaskInstance: {self.dag_id}.{self.task_id} {self.run_id} [{self.state}]>"
+        prefix = f"<TaskInstance: {self.dag_id}.{self.task_id} {self.run_id} "
+        if self.map_index != -1:
+            prefix += f"map_index={self.map_index} "
+        return prefix + f"[{self.state}]>"
 
     def next_retry_datetime(self):
         """
@@ -1312,6 +1315,11 @@ class TaskInstance(Base, LoggingMixin):
         :param pool: specifies the pool to use to run the task instance
         :param session: SQLAlchemy ORM Session
         """
+        if self.task.is_mapped:
+            raise RuntimeError(
+                f'task property of {self.task_id!r} was still a MappedOperator -- it should have been '
+                'expanded already!'
+            )
         self.test_mode = test_mode
         self.refresh_from_task(self.task, pool_override=pool)
         self.refresh_from_db(session=session)
@@ -1719,6 +1727,8 @@ class TaskInstance(Base, LoggingMixin):
             self.refresh_from_db(session)
 
         task = self.task
+        if task.is_mapped:
+            task = task.unmap()
         self.end_date = timezone.utcnow()
         self.set_duration()
         Stats.incr(f'operator_failures_{task.task_type}', 1, 1)
@@ -2252,19 +2262,29 @@ class TaskInstance(Base, LoggingMixin):
 
         dag_id = first.dag_id
         run_id = first.run_id
+        map_index = first.map_index
         first_task_id = first.task_id
         # Common path optimisations: when all TIs are for the same dag_id and run_id, or same dag_id
-        # and task_id -- this can be over 150x for huge numbers of TIs (20k+)
-        if all(t.dag_id == dag_id and t.run_id == run_id for t in tis):
+        # and task_id -- this can be over 150x faster for huge numbers of TIs (20k+)
+        if all(t.dag_id == dag_id and t.run_id == run_id and t.map_index == map_index for t in tis):
             return and_(
                 TaskInstance.dag_id == dag_id,
                 TaskInstance.run_id == run_id,
+                TaskInstance.map_index == map_index,
                 TaskInstance.task_id.in_(t.task_id for t in tis),
             )
-        if all(t.dag_id == dag_id and t.task_id == first_task_id for t in tis):
+        if all(t.dag_id == dag_id and t.task_id == first_task_id and t.map_index == map_index for t in tis):
             return and_(
                 TaskInstance.dag_id == dag_id,
                 TaskInstance.run_id.in_(t.run_id for t in tis),
+                TaskInstance.map_index == map_index,
+                TaskInstance.task_id == first_task_id,
+            )
+        if all(t.dag_id == dag_id and t.run_id == run_id and t.task_id == first_task_id for t in tis):
+            return and_(
+                TaskInstance.dag_id == dag_id,
+                TaskInstance.run_id == run_id,
+                TaskInstance.map_index.in_(t.map_index for t in tis),
                 TaskInstance.task_id == first_task_id,
             )
 
@@ -2274,13 +2294,14 @@ class TaskInstance(Base, LoggingMixin):
                     TaskInstance.dag_id == ti.dag_id,
                     TaskInstance.task_id == ti.task_id,
                     TaskInstance.run_id == ti.run_id,
+                    TaskInstance.map_index == ti.map_index,
                 )
                 for ti in tis
             )
         else:
-            return tuple_(TaskInstance.dag_id, TaskInstance.task_id, TaskInstance.run_id).in_(
-                [ti.key.primary for ti in tis]
-            )
+            return tuple_(
+                TaskInstance.dag_id, TaskInstance.task_id, TaskInstance.run_id, TaskInstance.map_index
+            ).in_([ti.key.primary for ti in tis])
 
 
 # State of the task instance.
diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py
index 4fc9566..7c06155 100644
--- a/airflow/models/taskmixin.py
+++ b/airflow/models/taskmixin.py
@@ -17,7 +17,7 @@
 
 import warnings
 from abc import ABCMeta, abstractmethod
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Sequence, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sequence, Set, Tuple, Union
 
 import pendulum
 
@@ -109,6 +109,8 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
     """
 
     dag: Optional["DAG"] = None
+    task_group: Optional["TaskGroup"] = None
+    """The task_group that contains this node"""
 
     @property
     @abstractmethod
@@ -117,15 +119,12 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
 
     @property
     def label(self) -> Optional[str]:
-        tg: Optional["TaskGroup"] = getattr(self, 'task_group', None)
+        tg = self.task_group
         if tg and tg.node_id and tg.prefix_group_id:
             # "task_group_id.task_id" -> "task_id"
             return self.node_id[len(tg.node_id) + 1 :]
         return self.node_id
 
-    task_group: Optional["TaskGroup"]
-    """The task_group that contains this node"""
-
     start_date: Optional[pendulum.DateTime]
     end_date: Optional[pendulum.DateTime]
     upstream_task_ids: Set[str]
@@ -268,3 +267,46 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
     def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]:
         """This is used by SerializedTaskGroup to serialize a task group's content."""
         raise NotImplementedError()
+
+    def mapped_dependants(self) -> Iterator["DAGNode"]:
+        """Return any mapped nodes that are direct dependencies of the current task
+
+        For now, this walks the entire DAG to find mapped nodes that have this
+        current task as an upstream. We cannot use ``downstream_list`` since it
+        only contains operators, not task groups. In the future, we should
+        provide a way to record all of a DAG node's downstream nodes instead.
+        """
+        from airflow.models.baseoperator import MappedOperator
+        from airflow.utils.task_group import MappedTaskGroup, TaskGroup
+
+        def _walk_group(group: TaskGroup) -> Iterable[Tuple[str, DAGNode]]:
+            """Recursively walk children in a task group.
+
+            This yields all direct children (including both tasks and task
+            groups), and all children of any task groups.
+            """
+            for key, child in group.children.items():
+                yield key, child
+                if isinstance(child, TaskGroup):
+                    yield from _walk_group(child)
+
+        tg = self.task_group
+        if not tg:
+            raise RuntimeError("Cannot check for mapped_dependants when not attached to a DAG")
+        for key, child in _walk_group(tg):
+            if key == self.node_id:
+                continue
+            if not isinstance(child, (MappedOperator, MappedTaskGroup)):
+                continue
+            if self.node_id in child.upstream_task_ids:
+                yield child
+
+    def has_mapped_dependants(self) -> bool:
+        """Whether any downstream dependencies depend on this task for mapping.
+
+        For now, this walks the entire DAG to find mapped nodes that have this
+        current task as an upstream. We cannot use ``downstream_list`` since it
+        only contains operators, not task groups. In the future, we should
+        provide a way to record all of a DAG node's downstream nodes instead.
+        """
+        return any(self.mapped_dependants())
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 63820ff..42fa314 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -19,6 +19,7 @@
 import datetime
 import enum
 import logging
+import weakref
 from dataclasses import dataclass
 from inspect import Parameter, signature
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Set, Type, Union
@@ -567,7 +568,9 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
 
     @classmethod
     def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]:
-        serialize_op = cls._serialize_node(op)
+
+        stock_deps = op.deps is MappedOperator.DEFAULT_DEPS
+        serialize_op = cls._serialize_node(op, include_deps=not stock_deps)
         # It must be a class at this point for it to work, not a string
         assert isinstance(op.operator_class, type)
         serialize_op['_task_type'] = op.operator_class.__name__
@@ -577,10 +580,10 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
 
     @classmethod
     def serialize_operator(cls, op: BaseOperator) -> Dict[str, Any]:
-        return cls._serialize_node(op)
+        return cls._serialize_node(op, include_deps=op.deps is not BaseOperator.deps)
 
     @classmethod
-    def _serialize_node(cls, op: Union[BaseOperator, MappedOperator]) -> Dict[str, Any]:
+    def _serialize_node(cls, op: Union[BaseOperator, MappedOperator], include_deps: bool) -> Dict[str, Any]:
         """Serializes operator into a JSON object."""
         serialize_op = cls.serialize_to_json(op, cls._decorated_fields)
         serialize_op['_task_type'] = type(op).__name__
@@ -594,8 +597,8 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
                 op.operator_extra_links
             )
 
-        if op.deps is not BaseOperator.deps:
-            # Are the deps different to BaseOperator, if so serialize the class names!
+        if include_deps:
+            # Are the deps different to "stock", if so serialize the class names!
             # For Airflow 2.0 expediency we _only_ allow built in Dep classes.
             # Fix this for 2.0.x or 2.1
             deps = []
@@ -641,7 +644,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
                 # These are all re-set later
                 partial_kwargs={},
                 mapped_kwargs={},
-                deps=tuple(),
+                deps=MappedOperator.DEFAULT_DEPS,
                 is_dummy=False,
                 template_fields=(),
                 template_ext=(),
@@ -1084,8 +1087,13 @@ class SerializedTaskGroup(TaskGroup, BaseSerialization):
             for key in ["prefix_group_id", "tooltip", "ui_color", "ui_fgcolor"]
         }
         group = SerializedTaskGroup(group_id=group_id, parent_group=parent_group, **kwargs)
+
+        def set_ref(task: BaseOperator) -> BaseOperator:
+            task.task_group = weakref.proxy(group)
+            return task
+
         group.children = {
-            label: task_dict[val]  # type: ignore
+            label: set_ref(task_dict[val])  # type: ignore
             if _type == DAT.OP  # type: ignore
             else SerializedTaskGroup.deserialize_task_group(val, group, task_dict)
             for label, (_type, val) in encoded_group["children"].items()
diff --git a/tests/models/__init__.py b/airflow/ti_deps/deps/mapped_task_expanded.py
similarity index 60%
copy from tests/models/__init__.py
copy to airflow/ti_deps/deps/mapped_task_expanded.py
index 2d4a0d9..03cf07d 100644
--- a/tests/models/__init__.py
+++ b/airflow/ti_deps/deps/mapped_task_expanded.py
@@ -15,10 +15,18 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 
-import os
 
-from airflow.utils import timezone
+class MappedTaskIsExpanded(BaseTIDep):
+    """Checks that a mapped task has been expanded before it's TaskInstance can run."""
 
-DEFAULT_DATE = timezone.datetime(2016, 1, 1)
-TEST_DAGS_FOLDER = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'dags'))
+    NAME = "Task has been mapped"
+    IGNORABLE = False
+    IS_TASK_DEP = False
+
+    def _get_dep_statuses(self, ti, session, dep_context):
+        if ti.map_index == -1:
+            yield self._failing_status(reason="The task has yet to be mapped!")
+            return
+        yield self._passing_status(reason="The task has been mapped")
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index 8f193f4..88b956e 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -92,7 +92,6 @@ class TaskGroup(DAGNode):
             # used_group_ids is shared across all TaskGroups in the same DAG to keep track
             # of used group_id to avoid duplication.
             self.used_group_ids = set()
-            self._parent_group = None
             self.dag = dag
         else:
             if prefix_group_id:
@@ -108,28 +107,29 @@ class TaskGroup(DAGNode):
             if not parent_group and not dag:
                 raise AirflowException("TaskGroup can only be used inside a dag")
 
-            self._parent_group = parent_group or TaskGroupContext.get_current_task_group(dag)
-            if not self._parent_group:
+            parent_group = parent_group or TaskGroupContext.get_current_task_group(dag)
+            if not parent_group:
                 raise AirflowException("TaskGroup must have a parent_group except for the root TaskGroup")
-            if dag is not self._parent_group.dag:
+            if dag is not parent_group.dag:
                 raise RuntimeError(
-                    "Cannot mix TaskGroups from different DAGs: %s and %s", dag, self._parent_group.dag
+                    "Cannot mix TaskGroups from different DAGs: %s and %s", dag, parent_group.dag
                 )
 
-            self.used_group_ids = self._parent_group.used_group_ids
+            self.used_group_ids = parent_group.used_group_ids
 
         # if given group_id already used assign suffix by incrementing largest used suffix integer
         # Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3
         self._group_id = group_id
         self._check_for_group_id_collisions(add_suffix_on_collision)
 
+        self.children: Dict[str, DAGNode] = {}
+        if parent_group:
+            parent_group.add(self)
+
         self.used_group_ids.add(self.group_id)
         if self.group_id:
             self.used_group_ids.add(self.downstream_join_id)
             self.used_group_ids.add(self.upstream_join_id)
-        self.children: Dict[str, DAGNode] = {}
-        if self._parent_group:
-            self._parent_group.add(self)
 
         self.tooltip = tooltip
         self.ui_color = ui_color
@@ -175,6 +175,10 @@ class TaskGroup(DAGNode):
         """Returns True if this TaskGroup is the root TaskGroup. Otherwise False"""
         return not self.group_id
 
+    @property
+    def parent_group(self) -> Optional["TaskGroup"]:
+        return self.task_group
+
     def __iter__(self):
         for child in self.children.values():
             if isinstance(child, TaskGroup):
@@ -184,6 +188,8 @@ class TaskGroup(DAGNode):
 
     def add(self, task: DAGNode) -> None:
         """Add a task to this TaskGroup."""
+        # Set the TG first, as setting it might change the return value of node_id!
+        task.task_group = weakref.proxy(self)
         key = task.node_id
 
         if key in self.children:
@@ -201,7 +207,6 @@ class TaskGroup(DAGNode):
                 raise AirflowException("Cannot add a non-empty TaskGroup")
 
         self.children[key] = task
-        task.task_group = weakref.proxy(self)
 
     def _remove(self, task: DAGNode) -> None:
         key = task.node_id
@@ -216,8 +221,8 @@ class TaskGroup(DAGNode):
     @property
     def group_id(self) -> Optional[str]:
         """group_id of this TaskGroup."""
-        if self._parent_group and self._parent_group.prefix_group_id and self._parent_group.group_id:
-            return self._parent_group.child_id(self._group_id)
+        if self.task_group and self.task_group.prefix_group_id and self.task_group.group_id:
+            return self.task_group.child_id(self._group_id)
 
         return self._group_id
 
@@ -380,8 +385,8 @@ class TaskGroup(DAGNode):
             raise RuntimeError("Cannot map a TaskGroup that already has children")
         if not self.group_id:
             raise RuntimeError("Cannot map a TaskGroup before it has a group_id")
-        if self._parent_group:
-            self._parent_group._remove(self)
+        if self.task_group:
+            self.task_group._remove(self)
         return MappedTaskGroup(group_id=self._group_id, dag=self.dag, mapped_arg=arg)
 
 
diff --git a/tests/models/__init__.py b/tests/dags/test_mapped_classic.py
similarity index 65%
copy from tests/models/__init__.py
copy to tests/dags/test_mapped_classic.py
index 2d4a0d9..14f2c1f 100644
--- a/tests/models/__init__.py
+++ b/tests/dags/test_mapped_classic.py
@@ -1,4 +1,3 @@
-#
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -16,9 +15,20 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import os
+from airflow import DAG
+from airflow.decorators import task
+from airflow.operators.python import PythonOperator
+from airflow.utils.dates import days_ago
+
+
+@task
+def make_list():
+    return [1, 2, {'a': 'b'}]
+
+
+def consumer(*args):
+    print(repr(args))
 
-from airflow.utils import timezone
 
-DEFAULT_DATE = timezone.datetime(2016, 1, 1)
-TEST_DAGS_FOLDER = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'dags'))
+with DAG(dag_id='test_mapped_classic', start_date=days_ago(2)) as dag:
+    PythonOperator(task_id='consumer', python_callable=consumer).map(op_args=make_list())
diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py
index 340e67b..af79b16 100644
--- a/tests/executors/test_kubernetes_executor.py
+++ b/tests/executors/test_kubernetes_executor.py
@@ -29,6 +29,7 @@ from kubernetes.client.rest import ApiException
 from urllib3 import HTTPResponse
 
 from airflow import AirflowException
+from airflow.models.taskinstance import TaskInstanceKey
 from airflow.utils import timezone
 from tests.test_utils.config import conf_vars
 
@@ -244,7 +245,7 @@ class TestKubernetesExecutor:
             kubernetes_executor.start()
             # Execute a task while the Api Throws errors
             try_number = 1
-            task_instance_key = ('dag', 'task', 'run_id', try_number)
+            task_instance_key = TaskInstanceKey('dag', 'task', 'run_id', try_number)
             kubernetes_executor.execute_async(
                 key=task_instance_key,
                 queue=None,
@@ -326,7 +327,7 @@ class TestKubernetesExecutor:
             assert executor.task_queue.empty()
 
             executor.execute_async(
-                key=('dag', 'task', 'run_id', 1),
+                key=TaskInstanceKey('dag', 'task', 'run_id', 1),
                 queue=None,
                 command=['airflow', 'tasks', 'run', 'true', 'some_parameter'],
                 executor_config={
diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py
index 37b5acc..0878f63 100644
--- a/tests/jobs/test_backfill_job.py
+++ b/tests/jobs/test_backfill_job.py
@@ -41,10 +41,12 @@ from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstanceKey
 from airflow.operators.dummy import DummyOperator
 from airflow.utils import timezone
+from airflow.utils.dates import days_ago
 from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.timeout import timeout
 from airflow.utils.types import DagRunType
+from tests.models import TEST_DAGS_FOLDER
 from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs, set_default_pool_slots
 from tests.test_utils.mock_executor import MockExecutor
 from tests.test_utils.timetables import cron_timetable
@@ -190,7 +192,7 @@ class TestBackfillJob:
             ("run_this_last", end_date),
         ]
         assert [
-            ((dag.dag_id, task_id, f'backfill__{when.isoformat()}', 1), (State.SUCCESS, None))
+            ((dag.dag_id, task_id, f'backfill__{when.isoformat()}', 1, -1), (State.SUCCESS, None))
             for (task_id, when) in expected_execution_order
         ] == executor.sorted_tasks
 
@@ -267,7 +269,7 @@ class TestBackfillJob:
 
         job.run()
         assert [
-            ((dag_id, task_id, f'backfill__{DEFAULT_DATE.isoformat()}', 1), (State.SUCCESS, None))
+            ((dag_id, task_id, f'backfill__{DEFAULT_DATE.isoformat()}', 1, -1), (State.SUCCESS, None))
             for task_id in expected_execution_order
         ] == executor.sorted_tasks
 
@@ -1230,12 +1232,11 @@ class TestBackfillJob:
         subdag.clear()
         dag.clear()
 
-    def test_update_counters(self, dag_maker):
-        with dag_maker(dag_id='test_manage_executor_state', start_date=DEFAULT_DATE) as dag:
+    def test_update_counters(self, dag_maker, session):
+        with dag_maker(dag_id='test_manage_executor_state', start_date=DEFAULT_DATE, session=session) as dag:
             task1 = DummyOperator(task_id='dummy', owner='airflow')
         dr = dag_maker.create_dagrun()
         job = BackfillJob(dag=dag)
-        session = settings.Session()
 
         ti = TI(task1, dr.execution_date)
         ti.refresh_from_db()
@@ -1245,7 +1246,7 @@ class TestBackfillJob:
         # test for success
         ti.set_state(State.SUCCESS, session)
         ti_status.running[ti.key] = ti
-        job._update_counters(ti_status=ti_status)
+        job._update_counters(ti_status=ti_status, session=session)
         assert len(ti_status.running) == 0
         assert len(ti_status.succeeded) == 1
         assert len(ti_status.skipped) == 0
@@ -1257,7 +1258,7 @@ class TestBackfillJob:
         # test for skipped
         ti.set_state(State.SKIPPED, session)
         ti_status.running[ti.key] = ti
-        job._update_counters(ti_status=ti_status)
+        job._update_counters(ti_status=ti_status, session=session)
         assert len(ti_status.running) == 0
         assert len(ti_status.succeeded) == 0
         assert len(ti_status.skipped) == 1
@@ -1269,7 +1270,7 @@ class TestBackfillJob:
         # test for failed
         ti.set_state(State.FAILED, session)
         ti_status.running[ti.key] = ti
-        job._update_counters(ti_status=ti_status)
+        job._update_counters(ti_status=ti_status, session=session)
         assert len(ti_status.running) == 0
         assert len(ti_status.succeeded) == 0
         assert len(ti_status.skipped) == 0
@@ -1281,7 +1282,7 @@ class TestBackfillJob:
         # test for retry
         ti.set_state(State.UP_FOR_RETRY, session)
         ti_status.running[ti.key] = ti
-        job._update_counters(ti_status=ti_status)
+        job._update_counters(ti_status=ti_status, session=session)
         assert len(ti_status.running) == 0
         assert len(ti_status.succeeded) == 0
         assert len(ti_status.skipped) == 0
@@ -1297,7 +1298,7 @@ class TestBackfillJob:
         ti.set_state(State.UP_FOR_RESCHEDULE, session)
         assert ti.try_number == 3  # see ti.try_number property in taskinstance module
         ti_status.running[ti.key] = ti
-        job._update_counters(ti_status=ti_status)
+        job._update_counters(ti_status=ti_status, session=session)
         assert len(ti_status.running) == 0
         assert len(ti_status.succeeded) == 0
         assert len(ti_status.skipped) == 0
@@ -1315,7 +1316,7 @@ class TestBackfillJob:
         session.merge(ti)
         session.commit()
         ti_status.running[ti.key] = ti
-        job._update_counters(ti_status=ti_status)
+        job._update_counters(ti_status=ti_status, session=session)
         assert len(ti_status.running) == 0
         assert len(ti_status.succeeded) == 0
         assert len(ti_status.skipped) == 0
@@ -1510,3 +1511,22 @@ class TestBackfillJob:
         )
         job.run()
         assert executor.job_id is not None
+
+    def test_mapped_dag(self, dag_maker):
+        """End-to-end test of a simple mapped dag"""
+        # Use SequentialExecutor for more predictable test behaviour
+        from airflow.executors.sequential_executor import SequentialExecutor
+
+        self.dagbag.process_file(str(TEST_DAGS_FOLDER / 'test_mapped_classic.py'))
+        dag = self.dagbag.get_dag('test_mapped_classic')
+
+        # This needs a real executor to run, so that the `make_list` task can write out the TaskMap
+
+        job = BackfillJob(
+            dag=dag,
+            start_date=days_ago(1),
+            end_date=days_ago(1),
+            donot_pickle=True,
+            executor=SequentialExecutor(),
+        )
+        job.run()
diff --git a/tests/models/__init__.py b/tests/models/__init__.py
index 2d4a0d9..c1cbabd 100644
--- a/tests/models/__init__.py
+++ b/tests/models/__init__.py
@@ -16,9 +16,9 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import os
+import pathlib
 
 from airflow.utils import timezone
 
 DEFAULT_DATE = timezone.datetime(2016, 1, 1)
-TEST_DAGS_FOLDER = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'dags'))
+TEST_DAGS_FOLDER = pathlib.Path(__file__).parent.with_name('dags')
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index ffde6ee..9f3e8ad 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -718,6 +718,7 @@ def test_task_mapping_with_dag():
 
     assert task1.downstream_list == [mapped]
     assert mapped in dag.tasks
+    assert mapped.task_group == dag.task_group
     # At parse time there should only be three tasks!
     assert len(dag.tasks) == 3
 
@@ -799,11 +800,21 @@ def test_partial_on_class_invalid_ctor_args() -> None:
     ["num_existing_tis", "expected"],
     (
         pytest.param(0, [(0, None), (1, None), (2, None)], id='only-unmapped-ti-exists'),
-        pytest.param(3, [(0, None), (1, None), (2, None)], id='all-tis-exist'),
+        pytest.param(
+            3,
+            [(0, 'success'), (1, 'success'), (2, 'success')],
+            id='all-tis-exist',
+        ),
         pytest.param(
             5,
-            [(0, None), (1, None), (2, None), (3, TaskInstanceState.REMOVED), (4, TaskInstanceState.REMOVED)],
-            id="tis-to-be-remove",
+            [
+                (0, 'success'),
+                (1, 'success'),
+                (2, 'success'),
+                (3, TaskInstanceState.REMOVED),
+                (4, TaskInstanceState.REMOVED),
+            ],
+            id="tis-to-be-removed",
         ),
     ),
 )
@@ -836,7 +847,8 @@ def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expec
         ).delete()
 
     for index in range(num_existing_tis):
-        ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index)  # type: ignore
+        # Give the existing TIs a state to make sure we don't change them
+        ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, state=TaskInstanceState.SUCCESS)
         session.add(ti)
     session.flush()
 
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index f7e40ef..1eb3328 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -2255,7 +2255,7 @@ def test_set_task_instance_state(run_id, execution_date, session, dag_maker):
     # dagrun should be set to QUEUED
     assert dagrun.get_state() == State.QUEUED
 
-    assert {t.key for t in altered} == {('test_set_task_instance_state', 'task_1', dagrun.run_id, 1)}
+    assert {t.key for t in altered} == {('test_set_task_instance_state', 'task_1', dagrun.run_id, 1, -1)}
 
 
 @pytest.mark.parametrize(
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index ff4207a..bbf05c3 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -1947,7 +1947,7 @@ class TestTaskInstance:
 
         with dag_maker('test-dag', session=session) as dag:
             task = BashOperator(task_id='op1', bash_command="{{ task.task_id }}")
-        dag.fileloc = TEST_DAGS_FOLDER + '/test_get_k8s_pod_yaml.py'
+        dag.fileloc = TEST_DAGS_FOLDER / 'test_get_k8s_pod_yaml.py'
         ti = dag_maker.create_dagrun().task_instances[0]
         ti.task = task
 
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 35d0d68..447b173 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -488,6 +488,9 @@ class TestStringifiedDAGs:
         assert not isinstance(task, SerializedBaseOperator)
         assert isinstance(task, BaseOperator)
 
+        # Every task should have a task_group property -- even if it's the DAG's root task group
+        assert serialized_task.task_group
+
         fields_to_check = task.get_serialized_fields() - {
             # Checked separately
             '_task_type',
@@ -1608,6 +1611,7 @@ def test_mapped_operator_serde():
 
     op = SerializedBaseOperator.deserialize_operator(serialized)
     assert isinstance(op, MappedOperator)
+    assert op.deps is MappedOperator.DEFAULT_DEPS
 
     assert op.operator_class == "airflow.operators.bash.BashOperator"
     assert op.mapped_kwargs['bash_command'] == literal
@@ -1637,6 +1641,7 @@ def test_mapped_operator_xcomarg_serde():
     }
 
     op = SerializedBaseOperator.deserialize_operator(serialized)
+    assert op.deps is MappedOperator.DEFAULT_DEPS
 
     arg = op.mapped_kwargs['arg2']
     assert arg.task_id == 'op1'
diff --git a/tests/test_utils/mock_executor.py b/tests/test_utils/mock_executor.py
index 37f49cf..23d32b6 100644
--- a/tests/test_utils/mock_executor.py
+++ b/tests/test_utils/mock_executor.py
@@ -59,10 +59,10 @@ class MockExecutor(BaseExecutor):
             # for tests!
             def sort_by(item):
                 key, val = item
-                (dag_id, task_id, date, try_number) = key
+                (dag_id, task_id, date, try_number, map_index) = key
                 (_, prio, _, _) = val
                 # Sort by priority (DESC), then date,task, try
-                return -prio, date, dag_id, task_id, try_number
+                return -prio, date, dag_id, task_id, map_index, try_number
 
             open_slots = self.parallelism - len(self.running)
             sorted_queue = sorted(self.queued_tasks.items(), key=sort_by)