You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2022/02/04 14:25:13 UTC
[airflow] branch main updated: Make `airflow dags test` be able to execute Mapped Tasks (#21210)
This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6fc6edf Make `airflow dags test` be able to execute Mapped Tasks (#21210)
6fc6edf is described below
commit 6fc6edf6af7f676bfa54ff3a2e6e6d2edb938f2e
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Fri Feb 4 14:24:32 2022 +0000
Make `airflow dags test` be able to execute Mapped Tasks (#21210)
* Make `airflow dags test` be able to execute Mapped Tasks
Two steps were required to do this:
- The BackfillJob needs to know about mapped tasks, both to expand them
and to update its TI tracking
- The DebugExecutor needed to "unmap" the mapped task to get the real
operator back
I was testing this with the following dag:
```
from airflow import DAG
from airflow.decorators import task
from airflow.operators.python import PythonOperator
import pendulum


@task
def make_list():
    """Return a list of shell ``echo`` command strings, one per sample value."""
    return list(map(lambda a: f'echo "{a!r}"', [1, 2, {'a': 'b'}]))


def consumer(*args):
    """Print the positional arguments this callable receives."""
    print(repr(args))


with DAG(dag_id='maptest', start_date=pendulum.DateTime(2022, 1, 18)) as dag:
    # Map the classic operator over the value returned by make_list().
    PythonOperator(task_id='consumer', python_callable=consumer).map(op_args=make_list())
```
It can't "unmap" decorated operators successfully yet, so we're using the
old-school PythonOperator.
We also just pass the whole value to the operator, not just the current
mapping value(s).
* Always have a `task_group` property on DAGNodes
Since TaskGroup is a DAGNode, we don't need to store the parent group
directly anymore -- it is already stored in the node's `task_group` property.
* Add "integration" tests for running mapped tasks via BackfillJob
* Only show "Map Index" in Backfill report when relevant
Co-authored-by: Tzu-ping Chung <ur...@gmail.com>
---
airflow/cli/commands/task_command.py | 2 +
airflow/executors/debug_executor.py | 2 +
airflow/executors/kubernetes_executor.py | 2 +-
airflow/jobs/backfill_job.py | 117 ++++++++++--------
airflow/jobs/local_task_job.py | 6 +
airflow/jobs/scheduler_job.py | 2 +-
airflow/models/baseoperator.py | 134 ++++++++++++---------
airflow/models/taskinstance.py | 51 +++++---
airflow/models/taskmixin.py | 52 +++++++-
airflow/serialization/serialized_objects.py | 22 ++--
.../ti_deps/deps/mapped_task_expanded.py | 16 ++-
airflow/utils/task_group.py | 33 ++---
.../__init__.py => dags/test_mapped_classic.py} | 20 ++-
tests/executors/test_kubernetes_executor.py | 5 +-
tests/jobs/test_backfill_job.py | 42 +++++--
tests/models/__init__.py | 4 +-
tests/models/test_baseoperator.py | 20 ++-
tests/models/test_dag.py | 2 +-
tests/models/test_taskinstance.py | 2 +-
tests/serialization/test_dag_serialization.py | 5 +
tests/test_utils/mock_executor.py | 4 +-
21 files changed, 366 insertions(+), 177 deletions(-)
diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index 537fab0..1b5208f 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -224,6 +224,8 @@ RAW_TASK_UNSUPPORTED_OPTION = [
def _run_raw_task(args, ti: TaskInstance) -> None:
"""Runs the main task handling code"""
+ if ti.task.is_mapped:
+ ti.task = ti.task.unmap()
ti._run_raw_task(
mark_success=args.mark_success,
job_id=args.job_id,
diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py
index 865186d..0ab5f35 100644
--- a/airflow/executors/debug_executor.py
+++ b/airflow/executors/debug_executor.py
@@ -76,6 +76,8 @@ class DebugExecutor(BaseExecutor):
key = ti.key
try:
params = self.tasks_params.pop(ti.key, {})
+ if ti.task.is_mapped:
+ ti.task = ti.task.unmap()
ti._run_raw_task(job_id=ti.job_id, **params)
self.change_state(key, State.SUCCESS)
ti._run_finished_callback()
diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py
index 1071a3a..ef671eb 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -296,7 +296,7 @@ class AirflowKubernetesScheduler(LoggingMixin):
"""
self.log.info('Kubernetes job is %s', str(next_job).replace("\n", " "))
key, command, kube_executor_config, pod_template_file = next_job
- dag_id, task_id, run_id, try_number = key
+ dag_id, task_id, run_id, try_number, _ = key
if command[0:3] != ["airflow", "tasks", "run"]:
raise ValueError('The command must start with ["airflow", "tasks", "run"].')
diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py
index 406c2ea..10a5d08 100644
--- a/airflow/jobs/backfill_job.py
+++ b/airflow/jobs/backfill_job.py
@@ -18,9 +18,9 @@
#
import time
-from collections import OrderedDict
-from typing import Optional, Set
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Tuple
+import attr
import pendulum
from sqlalchemy.orm.session import Session, make_transient
from tabulate import tabulate
@@ -48,6 +48,9 @@ from airflow.utils.session import provide_session
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import DagRunType
+if TYPE_CHECKING:
+ from airflow.models.baseoperator import MappedOperator
+
class BackfillJob(BaseJob):
"""
@@ -60,6 +63,7 @@ class BackfillJob(BaseJob):
__mapper_args__ = {'polymorphic_identity': 'BackfillJob'}
+ @attr.define
class _DagRunTaskStatus:
"""
Internal status of the backfill job. This class is intended to be instantiated
@@ -83,32 +87,17 @@ class BackfillJob(BaseJob):
:param total_runs: Number of total dag runs able to run
"""
- # TODO(edgarRd): AIRFLOW-1444: Add consistency check on counts
- def __init__(
- self,
- to_run=None,
- running=None,
- skipped=None,
- succeeded=None,
- failed=None,
- not_ready=None,
- deadlocked=None,
- active_runs=None,
- executed_dag_run_dates=None,
- finished_runs=0,
- total_runs=0,
- ):
- self.to_run = to_run or OrderedDict()
- self.running = running or {}
- self.skipped = skipped or set()
- self.succeeded = succeeded or set()
- self.failed = failed or set()
- self.not_ready = not_ready or set()
- self.deadlocked = deadlocked or set()
- self.active_runs = active_runs or []
- self.executed_dag_run_dates = executed_dag_run_dates or set()
- self.finished_runs = finished_runs
- self.total_runs = total_runs
+ to_run: Dict[TaskInstanceKey, TaskInstance] = attr.ib(factory=dict)
+ running: Dict[TaskInstanceKey, TaskInstance] = attr.ib(factory=dict)
+ skipped: Set[TaskInstanceKey] = attr.ib(factory=set)
+ succeeded: Set[TaskInstanceKey] = attr.ib(factory=set)
+ failed: Set[TaskInstanceKey] = attr.ib(factory=set)
+ not_ready: Set[TaskInstanceKey] = attr.ib(factory=set)
+ deadlocked: Set[TaskInstance] = attr.ib(factory=set)
+ active_runs: List[DagRun] = attr.ib(factory=list)
+ executed_dag_run_dates: Set[pendulum.DateTime] = attr.ib(factory=set)
+ finished_runs: int = 0
+ total_runs: int = 0
def __init__(
self,
@@ -167,7 +156,6 @@ class BackfillJob(BaseJob):
self.run_at_least_once = run_at_least_once
super().__init__(*args, **kwargs)
- @provide_session
def _update_counters(self, ti_status, session=None):
"""
Updates the counters per state of the tasks that were running. Can re-add
@@ -234,14 +222,22 @@ class BackfillJob(BaseJob):
session.query(TI).filter(filter_for_tis).update(
values={TI.state: TaskInstanceState.SCHEDULED}, synchronize_session=False
)
+ session.flush()
- def _manage_executor_state(self, running):
+ def _manage_executor_state(
+ self, running, session
+ ) -> Iterator[Tuple["MappedOperator", str, Sequence[TaskInstance]]]:
"""
Checks if the executor agrees with the state of task instances
- that are running
+ that are running.
+
+ Expands downstream mapped tasks when necessary
:param running: dict of key, task to verify
+ :return: An iterable of expanded TaskInstance per MappedTask
"""
+ from airflow.models.baseoperator import MappedOperator
+
executor = self.executor
# TODO: query all instead of refresh from db
@@ -266,6 +262,11 @@ class BackfillJob(BaseJob):
)
self.log.error(msg)
ti.handle_failure_with_callback(error=msg)
+ continue
+ if ti.state not in self.STATES_COUNT_AS_RUNNING:
+ for node in ti.task.mapped_dependants():
+ assert isinstance(node, MappedOperator)
+ yield node, ti.run_id, node.expand_mapped_task(ti, session)
@provide_session
def _get_dag_run(self, dagrun_info: DagRunInfo, dag: DAG, session: Session = None):
@@ -409,7 +410,6 @@ class BackfillJob(BaseJob):
# or leaf to root, as otherwise tasks might be
# determined deadlocked while they are actually
# waiting for their upstream to finish
- @provide_session
def _per_task_process(key, ti: TaskInstance, session=None):
ti.refresh_from_db(lock_for_update=True, session=session)
@@ -577,7 +577,8 @@ class BackfillJob(BaseJob):
"Not scheduling since Task concurrency limit is reached."
)
- _per_task_process(key, ti)
+ _per_task_process(key, ti, session)
+ session.commit()
except (NoAvailablePoolSlot, DagConcurrencyLimitReached, TaskConcurrencyLimitReached) as e:
self.log.debug(e)
@@ -597,11 +598,23 @@ class BackfillJob(BaseJob):
ti_status.deadlocked.update(ti_status.to_run.values())
ti_status.to_run.clear()
- # check executor state
- self._manage_executor_state(ti_status.running)
+ # check executor state -- and expand any mapped TIs
+ for node, run_id, mapped_tis in self._manage_executor_state(ti_status.running, session):
+
+ def to_keep(key: TaskInstanceKey) -> bool:
+ if key.dag_id != node.dag_id or key.task_id != node.task_id or key.run_id != run_id:
+ # For another Dag/Task/Run -- don't remove
+ return True
+ return False
+
+ # remove the old unmapped TIs for node -- they have been replaced with the mapped TIs
+ ti_status.to_run = {key: ti for (key, ti) in ti_status.to_run.items() if to_keep(key)}
+
+ ti_status.to_run.update({ti.key: ti for ti in mapped_tis})
# update the task counters
- self._update_counters(ti_status=ti_status)
+ self._update_counters(ti_status=ti_status, session=session)
+ session.commit()
# update dag run state
_dag_runs = ti_status.active_runs[:]
@@ -613,25 +626,33 @@ class BackfillJob(BaseJob):
executed_run_dates.append(run.execution_date)
self._log_progress(ti_status)
+ session.commit()
# return updated status
return executed_run_dates
@provide_session
- def _collect_errors(self, ti_status, session=None):
- def tabulate_ti_keys_set(set_ti_keys: Set[TaskInstanceKey]) -> str:
+ def _collect_errors(self, ti_status: _DagRunTaskStatus, session=None):
+ def tabulate_ti_keys_set(ti_keys: Iterable[TaskInstanceKey]) -> str:
# Sorting by execution date first
- sorted_ti_keys = sorted(
- set_ti_keys,
- key=lambda ti_key: (ti_key.run_id, ti_key.dag_id, ti_key.task_id, ti_key.try_number),
+ sorted_ti_keys: Any = sorted(
+ ti_keys,
+ key=lambda ti_key: (
+ ti_key.run_id,
+ ti_key.dag_id,
+ ti_key.task_id,
+ ti_key.map_index,
+ ti_key.try_number,
+ ),
)
- return tabulate(sorted_ti_keys, headers=["DAG ID", "Task ID", "Run ID", "Try number"])
- def tabulate_tis_set(set_tis: Set[TaskInstance]) -> str:
- # Sorting by execution date first
- sorted_tis = sorted(set_tis, key=lambda ti: (ti.run_id, ti.dag_id, ti.task_id, ti.try_number))
- tis_values = ((ti.dag_id, ti.task_id, ti.run_id, ti.try_number) for ti in sorted_tis)
- return tabulate(tis_values, headers=["DAG ID", "Task ID", "Run ID", "Try number"])
+ if all(key.map_index == -1 for key in ti_keys):
+ headers = ["DAG ID", "Task ID", "Run ID", "Try number"]
+ sorted_ti_keys = map(lambda k: k[0:4], sorted_ti_keys)
+ else:
+ headers = ["DAG ID", "Task ID", "Run ID", "Map Index", "Try number"]
+
+ return tabulate(sorted_ti_keys, headers=headers)
err = ''
if ti_status.failed:
@@ -667,7 +688,7 @@ class BackfillJob(BaseJob):
err += '\n\nThese tasks are skipped:\n'
err += tabulate_ti_keys_set(ti_status.skipped)
err += '\n\nThese tasks are deadlocked:\n'
- err += tabulate_tis_set(ti_status.deadlocked)
+ err += tabulate_ti_keys_set([ti.key for ti in ti_status.deadlocked])
return err
diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index c0255d7..05ee533 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -104,6 +104,12 @@ class LocalTaskJob(BaseJob):
try:
self.task_runner.start()
+ # Unmap the task _after_ it has forked/execed. (This is a bit of a kludge, but if we unmap before
+ # fork, then the "run_raw_task" command will see the mapping index and an Non-mapped task and
+ # fail)
+ if self.task_instance.task.is_mapped:
+ self.task_instance.task = self.task_instance.task.unmap()
+
heartbeat_time_limit = conf.getint('scheduler', 'scheduler_zombie_task_threshold')
# task callback invocation happens either here or in
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index cbda16e..7a6e3ef 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -534,7 +534,7 @@ class SchedulerJob(BaseJob):
"""Respond to executor events."""
if not self.processor_agent:
raise ValueError("Processor agent is not started.")
- ti_primary_key_to_try_number_map: Dict[Tuple[str, str, str], int] = {}
+ ti_primary_key_to_try_number_map: Dict[Tuple[str, str, str, int], int] = {}
event_buffer = self.executor.get_event_buffer()
tis_with_right_state: List[TaskInstanceKey] = []
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index d51dda8..34c8412 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -50,7 +50,7 @@ import attr
import jinja2
import pendulum
from dateutil.relativedelta import relativedelta
-from sqlalchemy import or_
+from sqlalchemy import func, or_
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import NoResultFound
@@ -66,6 +66,7 @@ from airflow.models.taskmixin import DAGNode, DependencyMixin
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.serialization.enums import DagAttributeTypes
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
+from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep
from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
@@ -247,7 +248,12 @@ class BaseOperatorMeta(abc.ABCMeta):
# Validate that the args we passed are known -- at call/DAG parse time, not run time!
_validate_kwarg_names_for_mapping(operator_class, "partial", kwargs)
return MappedOperator(
- task_id=task_id, operator_class=operator_class, dag=dag, partial_kwargs=kwargs, mapped_kwargs={}
+ task_id=task_id,
+ operator_class=operator_class,
+ dag=dag,
+ partial_kwargs=kwargs,
+ mapped_kwargs={},
+ deps=MappedOperator._deps(operator_class.deps),
)
@@ -1459,9 +1465,7 @@ class BaseOperator(Operator, LoggingMixin, DAGNode, metaclass=BaseOperatorMeta):
"""Return if this operator can use smart service. Default False."""
return False
- @property
- def is_mapped(self) -> bool:
- return False
+ is_mapped: ClassVar[bool] = False
@property
def inherits_from_dummy_operator(self):
@@ -1491,38 +1495,10 @@ class BaseOperator(Operator, LoggingMixin, DAGNode, metaclass=BaseOperatorMeta):
def map(self, **kwargs) -> "MappedOperator":
return MappedOperator.from_operator(self, kwargs)
- def has_mapped_dependants(self) -> bool:
- """Whether any downstream dependencies depend on this task for mapping.
-
- For now, this walks the entire DAG to find mapped nodes that has this
- current task as an upstream. We cannot use ``downstream_list`` since it
- only contains operators, not task groups. In the future, we should
- provide a way to record an DAG node's all downstream nodes instead.
- """
- from airflow.utils.task_group import MappedTaskGroup, TaskGroup
-
- if not self.has_dag():
- return False
-
- def _walk_group(group: TaskGroup) -> Iterable[Tuple[str, DAGNode]]:
- """Recursively walk children in a task group.
-
- This yields all direct children (including both tasks and task
- groups), and all children of any task groups.
- """
- for key, child in group.children.items():
- yield key, child
- if isinstance(child, TaskGroup):
- yield from _walk_group(child)
-
- for key, child in _walk_group(self.dag.task_group):
- if key == self.task_id:
- continue
- if not isinstance(child, (MappedOperator, MappedTaskGroup)):
- continue
- if self.task_id in child.upstream_task_ids:
- return True
- return False
+ def unmap(self) -> "BaseOperator":
+ """:meta private:"""
+ # Exists to make typing easier
+ raise TypeError("Internal code error: Do not call unmap on BaseOperator!")
def _validate_kwarg_names_for_mapping(
@@ -1591,7 +1567,7 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
# Needed for SerializedBaseOperator
_is_dummy: bool = attr.ib()
- deps: Iterable[BaseTIDep] = attr.ib()
+ deps: Iterable[BaseTIDep]
operator_extra_links: Iterable['BaseOperatorLink'] = ()
template_fields: Collection[str] = attr.ib()
template_ext: Collection[str] = attr.ib()
@@ -1602,16 +1578,16 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
subdag: None = attr.ib(init=False)
+ DEFAULT_DEPS: ClassVar[FrozenSet[BaseTIDep]] = frozenset(BaseOperator.deps) | frozenset(
+ [MappedTaskIsExpanded()]
+ )
+
@_is_dummy.default
def _is_dummy_from_operator_class(self):
from airflow.operators.dummy import DummyOperator
return issubclass(self.operator_class, DummyOperator)
- @deps.default
- def _deps_from_operator_class(self):
- return self.operator_class.deps
-
@template_fields.default
def _template_fields_from_operator_class(self):
return self.operator_class.template_fields
@@ -1648,7 +1624,8 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
@classmethod
def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) -> "MappedOperator":
- dag: Optional["DAG"] = getattr(operator, '_dag', None)
+ dag = operator.get_dag()
+ task_group = operator.task_group
if dag:
# When BaseOperator() was called within a DAG, it would have been added straight away, but now we
# are mapped, we want to _remove_ that task from the dag
@@ -1658,7 +1635,7 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
return MappedOperator(
operator_class=type(operator),
task_id=operator.task_id,
- task_group=operator.task_group,
+ task_group=task_group,
dag=dag,
upstream_task_ids=operator.upstream_task_ids,
downstream_task_ids=operator.downstream_task_ids,
@@ -1668,7 +1645,7 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
mapped_kwargs=mapped_kwargs,
owner=operator.owner,
max_active_tis_per_dag=operator.max_active_tis_per_dag,
- deps=operator.deps,
+ deps=cls._deps(operator.deps),
)
@classmethod
@@ -1695,12 +1672,20 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
task_id=task_id,
dag=dag,
task_group=task_group,
+ deps=cls._deps(decorator.operator_class.deps),
)
operator.mapped_kwargs.update(mapped_kwargs)
for arg in mapped_kwargs.values():
XComArg.apply_upstream_relationship(operator, arg)
return operator
+ @classmethod
+ def _deps(cls, deps: Iterable[BaseTIDep]):
+ if deps is BaseOperator.deps:
+ return cls.DEFAULT_DEPS
+ else:
+ return frozenset(deps) | {MappedTaskIsExpanded()}
+
def __attrs_post_init__(self):
from airflow.models.xcom_arg import XComArg
@@ -1756,9 +1741,7 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
"""Used to determine if an Operator is inherited from DummyOperator"""
return self._is_dummy
- @property
- def is_mapped(self) -> bool:
- return True
+ is_mapped: ClassVar[bool] = True
# The _serialized_fields are lazily loaded when get_serialized_fields() method is called
__serialized_fields: ClassVar[Optional[FrozenSet[str]]] = None
@@ -1777,6 +1760,7 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
'operator_extra_links',
'upstream_task_ids',
'task_type',
+ 'task_group',
# These are automatically populated from partial_kwargs. In
# a perfect world, they should be properties like other
# partial_kwargs-populated values e.g. 'queue' below, but we
@@ -1826,8 +1810,14 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
def depends_on_past(self) -> bool:
return self.partial_kwargs.get("depends_on_past") or self.wait_for_downstream
- def expand_mapped_task(self, upstream_ti: "TaskInstance", session: "Session" = NEW_SESSION) -> None:
- """Create the mapped TaskInstances for mapped task."""
+ def expand_mapped_task(
+ self, upstream_ti: "TaskInstance", session: "Session" = NEW_SESSION
+ ) -> Sequence[TaskInstance]:
+ """
+ Create the mapped TaskInstances for mapped task.
+
+ :return: The mapped TaskInstances
+ """
# TODO: support having multiuple mapped upstreams?
from airflow.models.taskmap import TaskMap
from airflow.settings import task_instance_mutation_hook
@@ -1846,6 +1836,7 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
# TODO: What would lead to this? How can this be better handled?
raise RuntimeError("mapped operator cannot be expanded; upstream not found")
+ state = None
unmapped_ti: Optional[TaskInstance] = (
session.query(TaskInstance)
.filter(
@@ -1858,6 +1849,8 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
.one_or_none()
)
+ ret: List[TaskInstance] = []
+
if unmapped_ti:
# The unmapped task instance still exists and is unfinished, i.e. we
# haven't tried to run it before.
@@ -1867,20 +1860,34 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
self.log.info("Marking %s as SKIPPED since the map has 0 values to expand", unmapped_ti)
unmapped_ti.state = TaskInstanceState.SKIPPED
session.flush()
- return
+ return ret
# Otherwise convert this into the first mapped index, and create
# TaskInstance for other indexes.
unmapped_ti.map_index = 0
+ state = unmapped_ti.state
+ self.log.debug("Updated in place to become %s", unmapped_ti)
+ ret.append(unmapped_ti)
indexes_to_map = range(1, task_map_info_length)
else:
- indexes_to_map = range(task_map_info_length)
+ # Only create "missing" ones.
+ current_max_mapping = (
+ session.query(func.max(TaskInstance.map_index))
+ .filter(
+ TaskInstance.dag_id == upstream_ti.dag_id,
+ TaskInstance.task_id == self.task_id,
+ TaskInstance.run_id == upstream_ti.run_id,
+ )
+ .scalar()
+ )
+ indexes_to_map = range(current_max_mapping + 1, task_map_info_length)
for index in indexes_to_map:
# TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
# TODO: Change `TaskInstance` ctor to take Operator, not BaseOperator
- ti = TaskInstance(self, run_id=upstream_ti.run_id, map_index=index) # type: ignore
+ ti = TaskInstance(self, run_id=upstream_ti.run_id, map_index=index, state=state) # type: ignore
+ self.log.debug("Expanding TIs upserted %s", ti)
task_instance_mutation_hook(ti)
- session.merge(ti)
+ ret.append(session.merge(ti))
# Set to "REMOVED" any (old) TaskInstances with map indices greater
# than the current map value
@@ -1893,6 +1900,25 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
session.flush()
+ return ret
+
+ def unmap(self) -> BaseOperator:
+ """Get the "normal" Operator after applying the current mapping"""
+ assert not isinstance(self.operator_class, str)
+
+ dag = self.get_dag()
+ if not dag:
+ raise RuntimeError("Cannot unmapp a task unless it has a dag")
+
+ args = {
+ **self.partial_kwargs,
+ **self.mapped_kwargs,
+ }
+ dag._remove_task(self.task_id)
+ task = self.operator_class(task_id=self.task_id, dag=self.dag, **args)
+
+ return task
+
# TODO: Deprecate for Airflow 3.0
Chainable = Union[DependencyMixin, Sequence[DependencyMixin]]
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 0528dbb..7f151f4d 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -303,20 +303,23 @@ class TaskInstanceKey(NamedTuple):
task_id: str
run_id: str
try_number: int = 1
+ map_index: int = -1
@property
- def primary(self) -> Tuple[str, str, str]:
+ def primary(self) -> Tuple[str, str, str, int]:
"""Return task instance primary key part of the key"""
- return self.dag_id, self.task_id, self.run_id
+ return self.dag_id, self.task_id, self.run_id, self.map_index
@property
def reduced(self) -> 'TaskInstanceKey':
"""Remake the key by subtracting 1 from try number to match in memory information"""
- return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, max(1, self.try_number - 1))
+ return TaskInstanceKey(
+ self.dag_id, self.task_id, self.run_id, max(1, self.try_number - 1), self.map_index
+ )
def with_try_number(self, try_number: int) -> 'TaskInstanceKey':
"""Returns TaskInstanceKey with provided ``try_number``"""
- return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, try_number)
+ return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, try_number, self.map_index)
@property
def key(self) -> "TaskInstanceKey":
@@ -795,8 +798,6 @@ class TaskInstance(Base, LoggingMixin):
else:
self.state = None
- self.log.debug("Refreshed TaskInstance %s", self)
-
def refresh_from_task(self, task: "BaseOperator", pool_override=None):
"""
Copy common attributes from the given task.
@@ -829,12 +830,11 @@ class TaskInstance(Base, LoggingMixin):
execution_date=self.execution_date,
session=session,
)
- self.log.debug("XCom data cleared")
@property
def key(self) -> TaskInstanceKey:
"""Returns a tuple that identifies the task instance uniquely"""
- return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number)
+ return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index)
@provide_session
def set_state(self, state: Optional[str], session=NEW_SESSION):
@@ -1068,7 +1068,10 @@ class TaskInstance(Base, LoggingMixin):
yield dep_status
def __repr__(self):
- return f"<TaskInstance: {self.dag_id}.{self.task_id} {self.run_id} [{self.state}]>"
+ prefix = f"<TaskInstance: {self.dag_id}.{self.task_id} {self.run_id} "
+ if self.map_index != -1:
+ prefix += f"map_index={self.map_index} "
+ return prefix + f"[{self.state}]>"
def next_retry_datetime(self):
"""
@@ -1312,6 +1315,11 @@ class TaskInstance(Base, LoggingMixin):
:param pool: specifies the pool to use to run the task instance
:param session: SQLAlchemy ORM Session
"""
+ if self.task.is_mapped:
+ raise RuntimeError(
+ f'task property of {self.task_id!r} was still a MappedOperator -- it should have been '
+ 'expanded already!'
+ )
self.test_mode = test_mode
self.refresh_from_task(self.task, pool_override=pool)
self.refresh_from_db(session=session)
@@ -1719,6 +1727,8 @@ class TaskInstance(Base, LoggingMixin):
self.refresh_from_db(session)
task = self.task
+ if task.is_mapped:
+ task = task.unmap()
self.end_date = timezone.utcnow()
self.set_duration()
Stats.incr(f'operator_failures_{task.task_type}', 1, 1)
@@ -2252,19 +2262,29 @@ class TaskInstance(Base, LoggingMixin):
dag_id = first.dag_id
run_id = first.run_id
+ map_index = first.map_index
first_task_id = first.task_id
# Common path optimisations: when all TIs are for the same dag_id and run_id, or same dag_id
- # and task_id -- this can be over 150x for huge numbers of TIs (20k+)
- if all(t.dag_id == dag_id and t.run_id == run_id for t in tis):
+ # and task_id -- this can be over 150x faster for huge numbers of TIs (20k+)
+ if all(t.dag_id == dag_id and t.run_id == run_id and t.map_index == map_index for t in tis):
return and_(
TaskInstance.dag_id == dag_id,
TaskInstance.run_id == run_id,
+ TaskInstance.map_index == map_index,
TaskInstance.task_id.in_(t.task_id for t in tis),
)
- if all(t.dag_id == dag_id and t.task_id == first_task_id for t in tis):
+ if all(t.dag_id == dag_id and t.task_id == first_task_id and t.map_index == map_index for t in tis):
return and_(
TaskInstance.dag_id == dag_id,
TaskInstance.run_id.in_(t.run_id for t in tis),
+ TaskInstance.map_index == map_index,
+ TaskInstance.task_id == first_task_id,
+ )
+ if all(t.dag_id == dag_id and t.run_id == run_id and t.task_id == first_task_id for t in tis):
+ return and_(
+ TaskInstance.dag_id == dag_id,
+ TaskInstance.run_id == run_id,
+ TaskInstance.map_index.in_(t.map_index for t in tis),
TaskInstance.task_id == first_task_id,
)
@@ -2274,13 +2294,14 @@ class TaskInstance(Base, LoggingMixin):
TaskInstance.dag_id == ti.dag_id,
TaskInstance.task_id == ti.task_id,
TaskInstance.run_id == ti.run_id,
+ TaskInstance.map_index == ti.map_index,
)
for ti in tis
)
else:
- return tuple_(TaskInstance.dag_id, TaskInstance.task_id, TaskInstance.run_id).in_(
- [ti.key.primary for ti in tis]
- )
+ return tuple_(
+ TaskInstance.dag_id, TaskInstance.task_id, TaskInstance.run_id, TaskInstance.map_index
+ ).in_([ti.key.primary for ti in tis])
# State of the task instance.
diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py
index 4fc9566..7c06155 100644
--- a/airflow/models/taskmixin.py
+++ b/airflow/models/taskmixin.py
@@ -17,7 +17,7 @@
import warnings
from abc import ABCMeta, abstractmethod
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Sequence, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sequence, Set, Tuple, Union
import pendulum
@@ -109,6 +109,8 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
"""
dag: Optional["DAG"] = None
+ task_group: Optional["TaskGroup"] = None
+ """The task_group that contains this node"""
@property
@abstractmethod
@@ -117,15 +119,12 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
@property
def label(self) -> Optional[str]:
- tg: Optional["TaskGroup"] = getattr(self, 'task_group', None)
+ tg = self.task_group
if tg and tg.node_id and tg.prefix_group_id:
# "task_group_id.task_id" -> "task_id"
return self.node_id[len(tg.node_id) + 1 :]
return self.node_id
- task_group: Optional["TaskGroup"]
- """The task_group that contains this node"""
-
start_date: Optional[pendulum.DateTime]
end_date: Optional[pendulum.DateTime]
upstream_task_ids: Set[str]
@@ -268,3 +267,46 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]:
"""This is used by SerializedTaskGroup to serialize a task group's content."""
raise NotImplementedError()
+
+ def mapped_dependants(self) -> Iterator["DAGNode"]:
+ """Return any mapped nodes that are direct dependencies of the current task
+
+ For now, this walks the entire DAG to find mapped nodes that has this
+ current task as an upstream. We cannot use ``downstream_list`` since it
+ only contains operators, not task groups. In the future, we should
+ provide a way to record an DAG node's all downstream nodes instead.
+ """
+ from airflow.models.baseoperator import MappedOperator
+ from airflow.utils.task_group import MappedTaskGroup, TaskGroup
+
+ def _walk_group(group: TaskGroup) -> Iterable[Tuple[str, DAGNode]]:
+ """Recursively walk children in a task group.
+
+ This yields all direct children (including both tasks and task
+ groups), and all children of any task groups.
+ """
+ for key, child in group.children.items():
+ yield key, child
+ if isinstance(child, TaskGroup):
+ yield from _walk_group(child)
+
+ tg = self.task_group
+ if not tg:
+ raise RuntimeError("Cannot check for mapped_dependants when not attached to a DAG")
+ for key, child in _walk_group(tg):
+ if key == self.node_id:
+ continue
+ if not isinstance(child, (MappedOperator, MappedTaskGroup)):
+ continue
+ if self.node_id in child.upstream_task_ids:
+ yield child
+
+ def has_mapped_dependants(self) -> bool:
+ """Whether any downstream dependencies depend on this task for mapping.
+
+ For now, this walks the entire DAG to find mapped nodes that has this
+ current task as an upstream. We cannot use ``downstream_list`` since it
+ only contains operators, not task groups. In the future, we should
+ provide a way to record an DAG node's all downstream nodes instead.
+ """
+ return any(self.mapped_dependants())
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 63820ff..42fa314 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -19,6 +19,7 @@
import datetime
import enum
import logging
+import weakref
from dataclasses import dataclass
from inspect import Parameter, signature
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Set, Type, Union
@@ -567,7 +568,9 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
@classmethod
def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]:
- serialize_op = cls._serialize_node(op)
+
+ stock_deps = op.deps is MappedOperator.DEFAULT_DEPS
+ serialize_op = cls._serialize_node(op, include_deps=not stock_deps)
# It must be a class at this point for it to work, not a string
assert isinstance(op.operator_class, type)
serialize_op['_task_type'] = op.operator_class.__name__
@@ -577,10 +580,10 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
@classmethod
def serialize_operator(cls, op: BaseOperator) -> Dict[str, Any]:
+ """Serialize a non-mapped operator; its deps are only written out when ``op.deps`` is not the stock ``BaseOperator.deps``."""
- return cls._serialize_node(op)
+ return cls._serialize_node(op, include_deps=op.deps is not BaseOperator.deps)
@classmethod
- def _serialize_node(cls, op: Union[BaseOperator, MappedOperator]) -> Dict[str, Any]:
+ def _serialize_node(cls, op: Union[BaseOperator, MappedOperator], include_deps: bool) -> Dict[str, Any]:
"""Serializes operator into a JSON object."""
serialize_op = cls.serialize_to_json(op, cls._decorated_fields)
serialize_op['_task_type'] = type(op).__name__
@@ -594,8 +597,8 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
op.operator_extra_links
)
- if op.deps is not BaseOperator.deps:
- # Are the deps different to BaseOperator, if so serialize the class names!
+ if include_deps:
+ # Are the deps different to "stock", if so serialize the class names!
# For Airflow 2.0 expediency we _only_ allow built in Dep classes.
# Fix this for 2.0.x or 2.1
deps = []
@@ -641,7 +644,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
# These are all re-set later
partial_kwargs={},
mapped_kwargs={},
- deps=tuple(),
+ deps=MappedOperator.DEFAULT_DEPS,
is_dummy=False,
template_fields=(),
template_ext=(),
@@ -1084,8 +1087,13 @@ class SerializedTaskGroup(TaskGroup, BaseSerialization):
for key in ["prefix_group_id", "tooltip", "ui_color", "ui_fgcolor"]
}
group = SerializedTaskGroup(group_id=group_id, parent_group=parent_group, **kwargs)
+
+ def set_ref(task: BaseOperator) -> BaseOperator:
+ """Point ``task.task_group`` at the group being deserialized (as a weak proxy, like TaskGroup.add does) and return ``task``."""
+ task.task_group = weakref.proxy(group)
+ return task
+
group.children = {
- label: task_dict[val] # type: ignore
+ label: set_ref(task_dict[val]) # type: ignore
if _type == DAT.OP # type: ignore
else SerializedTaskGroup.deserialize_task_group(val, group, task_dict)
for label, (_type, val) in encoded_group["children"].items()
diff --git a/tests/models/__init__.py b/airflow/ti_deps/deps/mapped_task_expanded.py
similarity index 60%
copy from tests/models/__init__.py
copy to airflow/ti_deps/deps/mapped_task_expanded.py
index 2d4a0d9..03cf07d 100644
--- a/tests/models/__init__.py
+++ b/airflow/ti_deps/deps/mapped_task_expanded.py
@@ -15,10 +15,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
-import os
-from airflow.utils import timezone
+class MappedTaskIsExpanded(BaseTIDep):
+ """Checks that a mapped task has been expanded before its TaskInstance can run."""
-DEFAULT_DATE = timezone.datetime(2016, 1, 1)
-TEST_DAGS_FOLDER = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'dags'))
+ NAME = "Task has been mapped"
+ IGNORABLE = False
+ IS_TASK_DEP = False
+
+ def _get_dep_statuses(self, ti, session, dep_context):
+ """Yield a failing status while ``ti`` is still unexpanded, otherwise a passing one."""
+ # A map_index of -1 marks a TI that has not been expanded into per-index TIs yet.
+ if ti.map_index == -1:
+ yield self._failing_status(reason="The task has yet to be mapped!")
+ return
+ yield self._passing_status(reason="The task has been mapped")
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index 8f193f4..88b956e 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -92,7 +92,6 @@ class TaskGroup(DAGNode):
# used_group_ids is shared across all TaskGroups in the same DAG to keep track
# of used group_id to avoid duplication.
self.used_group_ids = set()
- self._parent_group = None
self.dag = dag
else:
if prefix_group_id:
@@ -108,28 +107,29 @@ class TaskGroup(DAGNode):
if not parent_group and not dag:
raise AirflowException("TaskGroup can only be used inside a dag")
- self._parent_group = parent_group or TaskGroupContext.get_current_task_group(dag)
- if not self._parent_group:
+ parent_group = parent_group or TaskGroupContext.get_current_task_group(dag)
+ if not parent_group:
raise AirflowException("TaskGroup must have a parent_group except for the root TaskGroup")
-            if dag is not self._parent_group.dag:
+            if dag is not parent_group.dag:
                 raise RuntimeError(
-                    "Cannot mix TaskGroups from different DAGs: %s and %s", dag, self._parent_group.dag
+                    # Exceptions do not %-format extra args, so interpolate the message here.
+                    f"Cannot mix TaskGroups from different DAGs: {dag} and {parent_group.dag}"
                 )
- self.used_group_ids = self._parent_group.used_group_ids
+ self.used_group_ids = parent_group.used_group_ids
# if given group_id already used assign suffix by incrementing largest used suffix integer
# Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3
self._group_id = group_id
self._check_for_group_id_collisions(add_suffix_on_collision)
+ self.children: Dict[str, DAGNode] = {}
+ if parent_group:
+ parent_group.add(self)
+
self.used_group_ids.add(self.group_id)
if self.group_id:
self.used_group_ids.add(self.downstream_join_id)
self.used_group_ids.add(self.upstream_join_id)
- self.children: Dict[str, DAGNode] = {}
- if self._parent_group:
- self._parent_group.add(self)
self.tooltip = tooltip
self.ui_color = ui_color
@@ -175,6 +175,10 @@ class TaskGroup(DAGNode):
"""Returns True if this TaskGroup is the root TaskGroup. Otherwise False"""
return not self.group_id
+ @property
+ def parent_group(self) -> Optional["TaskGroup"]:
+ """Return the TaskGroup this group is nested in (the same object as ``task_group``, so possibly ``None``)."""
+ return self.task_group
+
def __iter__(self):
for child in self.children.values():
if isinstance(child, TaskGroup):
@@ -184,6 +188,8 @@ class TaskGroup(DAGNode):
def add(self, task: DAGNode) -> None:
"""Add a task to this TaskGroup."""
+ # Set the TG first, as setting it might change the return value of node_id!
+ task.task_group = weakref.proxy(self)
key = task.node_id
if key in self.children:
@@ -201,7 +207,6 @@ class TaskGroup(DAGNode):
raise AirflowException("Cannot add a non-empty TaskGroup")
self.children[key] = task
- task.task_group = weakref.proxy(self)
def _remove(self, task: DAGNode) -> None:
key = task.node_id
@@ -216,8 +221,8 @@ class TaskGroup(DAGNode):
@property
def group_id(self) -> Optional[str]:
"""group_id of this TaskGroup."""
- if self._parent_group and self._parent_group.prefix_group_id and self._parent_group.group_id:
- return self._parent_group.child_id(self._group_id)
+ # Prefix with the parent's id only when the parent opts in via prefix_group_id and has a group_id itself (the root group does not).
+ if self.task_group and self.task_group.prefix_group_id and self.task_group.group_id:
+ return self.task_group.child_id(self._group_id)
return self._group_id
@@ -380,8 +385,8 @@ class TaskGroup(DAGNode):
raise RuntimeError("Cannot map a TaskGroup that already has children")
if not self.group_id:
raise RuntimeError("Cannot map a TaskGroup before it has a group_id")
- if self._parent_group:
- self._parent_group._remove(self)
+ if self.task_group:
+ self.task_group._remove(self)
return MappedTaskGroup(group_id=self._group_id, dag=self.dag, mapped_arg=arg)
diff --git a/tests/models/__init__.py b/tests/dags/test_mapped_classic.py
similarity index 65%
copy from tests/models/__init__.py
copy to tests/dags/test_mapped_classic.py
index 2d4a0d9..14f2c1f 100644
--- a/tests/models/__init__.py
+++ b/tests/dags/test_mapped_classic.py
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -16,9 +15,20 @@
# specific language governing permissions and limitations
# under the License.
-import os
+from airflow import DAG
+from airflow.decorators import task
+from airflow.operators.python import PythonOperator
+from airflow.utils.dates import days_ago
+
+
+@task
+def make_list():
+ """Return the values that the ``consumer`` task is mapped over."""
+ return [1, 2, {'a': 'b'}]
+
+
+def consumer(*args):
+ """Print the positional arguments the mapped PythonOperator passes in via ``op_args``."""
+ print(repr(args))
-from airflow.utils import timezone
-DEFAULT_DATE = timezone.datetime(2016, 1, 1)
-TEST_DAGS_FOLDER = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'dags'))
+# Map a classic (non-decorated) PythonOperator over the output of the ``make_list`` task.
+with DAG(dag_id='test_mapped_classic', start_date=days_ago(2)) as dag:
+ PythonOperator(task_id='consumer', python_callable=consumer).map(op_args=make_list())
diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py
index 340e67b..af79b16 100644
--- a/tests/executors/test_kubernetes_executor.py
+++ b/tests/executors/test_kubernetes_executor.py
@@ -29,6 +29,7 @@ from kubernetes.client.rest import ApiException
from urllib3 import HTTPResponse
from airflow import AirflowException
+from airflow.models.taskinstance import TaskInstanceKey
from airflow.utils import timezone
from tests.test_utils.config import conf_vars
@@ -244,7 +245,7 @@ class TestKubernetesExecutor:
kubernetes_executor.start()
# Execute a task while the Api Throws errors
try_number = 1
- task_instance_key = ('dag', 'task', 'run_id', try_number)
+ task_instance_key = TaskInstanceKey('dag', 'task', 'run_id', try_number)
kubernetes_executor.execute_async(
key=task_instance_key,
queue=None,
@@ -326,7 +327,7 @@ class TestKubernetesExecutor:
assert executor.task_queue.empty()
executor.execute_async(
- key=('dag', 'task', 'run_id', 1),
+ key=TaskInstanceKey('dag', 'task', 'run_id', 1),
queue=None,
command=['airflow', 'tasks', 'run', 'true', 'some_parameter'],
executor_config={
diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py
index 37b5acc..0878f63 100644
--- a/tests/jobs/test_backfill_job.py
+++ b/tests/jobs/test_backfill_job.py
@@ -41,10 +41,12 @@ from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstanceKey
from airflow.operators.dummy import DummyOperator
from airflow.utils import timezone
+from airflow.utils.dates import days_ago
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.timeout import timeout
from airflow.utils.types import DagRunType
+from tests.models import TEST_DAGS_FOLDER
from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs, set_default_pool_slots
from tests.test_utils.mock_executor import MockExecutor
from tests.test_utils.timetables import cron_timetable
@@ -190,7 +192,7 @@ class TestBackfillJob:
("run_this_last", end_date),
]
assert [
- ((dag.dag_id, task_id, f'backfill__{when.isoformat()}', 1), (State.SUCCESS, None))
+ ((dag.dag_id, task_id, f'backfill__{when.isoformat()}', 1, -1), (State.SUCCESS, None))
for (task_id, when) in expected_execution_order
] == executor.sorted_tasks
@@ -267,7 +269,7 @@ class TestBackfillJob:
job.run()
assert [
- ((dag_id, task_id, f'backfill__{DEFAULT_DATE.isoformat()}', 1), (State.SUCCESS, None))
+ ((dag_id, task_id, f'backfill__{DEFAULT_DATE.isoformat()}', 1, -1), (State.SUCCESS, None))
for task_id in expected_execution_order
] == executor.sorted_tasks
@@ -1230,12 +1232,11 @@ class TestBackfillJob:
subdag.clear()
dag.clear()
- def test_update_counters(self, dag_maker):
- with dag_maker(dag_id='test_manage_executor_state', start_date=DEFAULT_DATE) as dag:
+ def test_update_counters(self, dag_maker, session):
+ with dag_maker(dag_id='test_manage_executor_state', start_date=DEFAULT_DATE, session=session) as dag:
task1 = DummyOperator(task_id='dummy', owner='airflow')
dr = dag_maker.create_dagrun()
job = BackfillJob(dag=dag)
- session = settings.Session()
ti = TI(task1, dr.execution_date)
ti.refresh_from_db()
@@ -1245,7 +1246,7 @@ class TestBackfillJob:
# test for success
ti.set_state(State.SUCCESS, session)
ti_status.running[ti.key] = ti
- job._update_counters(ti_status=ti_status)
+ job._update_counters(ti_status=ti_status, session=session)
assert len(ti_status.running) == 0
assert len(ti_status.succeeded) == 1
assert len(ti_status.skipped) == 0
@@ -1257,7 +1258,7 @@ class TestBackfillJob:
# test for skipped
ti.set_state(State.SKIPPED, session)
ti_status.running[ti.key] = ti
- job._update_counters(ti_status=ti_status)
+ job._update_counters(ti_status=ti_status, session=session)
assert len(ti_status.running) == 0
assert len(ti_status.succeeded) == 0
assert len(ti_status.skipped) == 1
@@ -1269,7 +1270,7 @@ class TestBackfillJob:
# test for failed
ti.set_state(State.FAILED, session)
ti_status.running[ti.key] = ti
- job._update_counters(ti_status=ti_status)
+ job._update_counters(ti_status=ti_status, session=session)
assert len(ti_status.running) == 0
assert len(ti_status.succeeded) == 0
assert len(ti_status.skipped) == 0
@@ -1281,7 +1282,7 @@ class TestBackfillJob:
# test for retry
ti.set_state(State.UP_FOR_RETRY, session)
ti_status.running[ti.key] = ti
- job._update_counters(ti_status=ti_status)
+ job._update_counters(ti_status=ti_status, session=session)
assert len(ti_status.running) == 0
assert len(ti_status.succeeded) == 0
assert len(ti_status.skipped) == 0
@@ -1297,7 +1298,7 @@ class TestBackfillJob:
ti.set_state(State.UP_FOR_RESCHEDULE, session)
assert ti.try_number == 3 # see ti.try_number property in taskinstance module
ti_status.running[ti.key] = ti
- job._update_counters(ti_status=ti_status)
+ job._update_counters(ti_status=ti_status, session=session)
assert len(ti_status.running) == 0
assert len(ti_status.succeeded) == 0
assert len(ti_status.skipped) == 0
@@ -1315,7 +1316,7 @@ class TestBackfillJob:
session.merge(ti)
session.commit()
ti_status.running[ti.key] = ti
- job._update_counters(ti_status=ti_status)
+ job._update_counters(ti_status=ti_status, session=session)
assert len(ti_status.running) == 0
assert len(ti_status.succeeded) == 0
assert len(ti_status.skipped) == 0
@@ -1510,3 +1511,22 @@ class TestBackfillJob:
)
job.run()
assert executor.job_id is not None
+
+    def test_mapped_dag(self, dag_maker):
+        """End-to-end test of a simple mapped dag"""
+        # Use SequentialExecutor for more predictable test behaviour
+        from airflow.executors.sequential_executor import SequentialExecutor
+
+        self.dagbag.process_file(str(TEST_DAGS_FOLDER / 'test_mapped_classic.py'))
+        dag = self.dagbag.get_dag('test_mapped_classic')
+        # Fail with a clear message instead of an obscure error inside BackfillJob
+        assert dag is not None, "test_mapped_classic.py did not produce a DAG"
+
+        # Compute the logical date once so the backfill window and the assertions agree
+        when = days_ago(1)
+
+        # This needs a real executor to run, so that the `make_list` task can write out the TaskMap
+        job = BackfillJob(
+            dag=dag,
+            start_date=when,
+            end_date=when,
+            donot_pickle=True,
+            executor=SequentialExecutor(),
+        )
+        job.run()
+
+        # The backfill should have created exactly one run for ``when`` and finished it successfully
+        dag_runs = DagRun.find(dag_id=dag.dag_id, execution_date=when)
+        assert len(dag_runs) == 1
+        assert dag_runs[0].state == State.SUCCESS
diff --git a/tests/models/__init__.py b/tests/models/__init__.py
index 2d4a0d9..c1cbabd 100644
--- a/tests/models/__init__.py
+++ b/tests/models/__init__.py
@@ -16,9 +16,9 @@
# specific language governing permissions and limitations
# under the License.
-import os
+import pathlib
from airflow.utils import timezone
DEFAULT_DATE = timezone.datetime(2016, 1, 1)
-TEST_DAGS_FOLDER = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'dags'))
+# resolve() follows symlinks like the previous os.path.realpath() did, so this is still <repo>/tests/dags
+TEST_DAGS_FOLDER = pathlib.Path(__file__).resolve().parent.with_name('dags')
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index ffde6ee..9f3e8ad 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -718,6 +718,7 @@ def test_task_mapping_with_dag():
assert task1.downstream_list == [mapped]
assert mapped in dag.tasks
+ assert mapped.task_group == dag.task_group
# At parse time there should only be three tasks!
assert len(dag.tasks) == 3
@@ -799,11 +800,21 @@ def test_partial_on_class_invalid_ctor_args() -> None:
["num_existing_tis", "expected"],
(
pytest.param(0, [(0, None), (1, None), (2, None)], id='only-unmapped-ti-exists'),
- pytest.param(3, [(0, None), (1, None), (2, None)], id='all-tis-exist'),
+ pytest.param(
+ 3,
+ [(0, 'success'), (1, 'success'), (2, 'success')],
+ id='all-tis-exist',
+ ),
pytest.param(
5,
- [(0, None), (1, None), (2, None), (3, TaskInstanceState.REMOVED), (4, TaskInstanceState.REMOVED)],
- id="tis-to-be-remove",
+ [
+ (0, 'success'),
+ (1, 'success'),
+ (2, 'success'),
+ (3, TaskInstanceState.REMOVED),
+ (4, TaskInstanceState.REMOVED),
+ ],
+ id="tis-to-be-removed",
),
),
)
@@ -836,7 +847,8 @@ def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expec
).delete()
for index in range(num_existing_tis):
- ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index) # type: ignore
+ # Give the existing TIs a state to make sure we don't change them
+ ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, state=TaskInstanceState.SUCCESS)
session.add(ti)
session.flush()
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index f7e40ef..1eb3328 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -2255,7 +2255,7 @@ def test_set_task_instance_state(run_id, execution_date, session, dag_maker):
# dagrun should be set to QUEUED
assert dagrun.get_state() == State.QUEUED
- assert {t.key for t in altered} == {('test_set_task_instance_state', 'task_1', dagrun.run_id, 1)}
+ assert {t.key for t in altered} == {('test_set_task_instance_state', 'task_1', dagrun.run_id, 1, -1)}
@pytest.mark.parametrize(
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index ff4207a..bbf05c3 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -1947,7 +1947,7 @@ class TestTaskInstance:
with dag_maker('test-dag', session=session) as dag:
task = BashOperator(task_id='op1', bash_command="{{ task.task_id }}")
- dag.fileloc = TEST_DAGS_FOLDER + '/test_get_k8s_pod_yaml.py'
+ dag.fileloc = TEST_DAGS_FOLDER / 'test_get_k8s_pod_yaml.py'
ti = dag_maker.create_dagrun().task_instances[0]
ti.task = task
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 35d0d68..447b173 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -488,6 +488,9 @@ class TestStringifiedDAGs:
assert not isinstance(task, SerializedBaseOperator)
assert isinstance(task, BaseOperator)
+ # Every task should have a task_group property -- even if it's the DAG's root task group
+ assert serialized_task.task_group
+
fields_to_check = task.get_serialized_fields() - {
# Checked separately
'_task_type',
@@ -1608,6 +1611,7 @@ def test_mapped_operator_serde():
op = SerializedBaseOperator.deserialize_operator(serialized)
assert isinstance(op, MappedOperator)
+ assert op.deps is MappedOperator.DEFAULT_DEPS
assert op.operator_class == "airflow.operators.bash.BashOperator"
assert op.mapped_kwargs['bash_command'] == literal
@@ -1637,6 +1641,7 @@ def test_mapped_operator_xcomarg_serde():
}
op = SerializedBaseOperator.deserialize_operator(serialized)
+ assert op.deps is MappedOperator.DEFAULT_DEPS
arg = op.mapped_kwargs['arg2']
assert arg.task_id == 'op1'
diff --git a/tests/test_utils/mock_executor.py b/tests/test_utils/mock_executor.py
index 37f49cf..23d32b6 100644
--- a/tests/test_utils/mock_executor.py
+++ b/tests/test_utils/mock_executor.py
@@ -59,10 +59,10 @@ class MockExecutor(BaseExecutor):
# for tests!
def sort_by(item):
key, val = item
- (dag_id, task_id, date, try_number) = key
+ (dag_id, task_id, date, try_number, map_index) = key
(_, prio, _, _) = val
- # Sort by priority (DESC), then date,task, try
+ # Sort by priority (DESC), then date, dag, task, map index, try
- return -prio, date, dag_id, task_id, try_number
+ return -prio, date, dag_id, task_id, map_index, try_number
open_slots = self.parallelism - len(self.running)
sorted_queue = sorted(self.queued_tasks.items(), key=sort_by)