You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by bb...@apache.org on 2022/04/12 18:56:03 UTC
[airflow] 01/06: Allow marking/clearing mapped taskinstances from the UI
This is an automated email from the ASF dual-hosted git repository.
bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 7a78a3efdb6820bcf4b3ee3fef27d8c1e4454b5d
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Tue Apr 12 16:37:35 2022 +0100
Allow marking/clearing mapped taskinstances from the UI
---
airflow/api/common/mark_tasks.py | 7 ++++++-
airflow/models/dag.py | 18 ++++++++++++++++--
airflow/www/views.py | 21 +++++++++++++++++++++
3 files changed, 43 insertions(+), 3 deletions(-)
diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index d11f490247..fe9fa0f490 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -79,6 +79,7 @@ def _create_dagruns(
def set_state(
*,
tasks: Iterable[Operator],
+ map_indexes: Optional[Iterable[int]] = None,
run_id: Optional[str] = None,
execution_date: Optional[datetime] = None,
upstream: bool = False,
@@ -97,6 +98,7 @@ def set_state(
on the schedule (but it will create subdag dag runs if needed).
:param tasks: the iterable of tasks from which to work. task.task.dag needs to be set
+ :param map_indexes: the map indexes of the tasks to set
:param run_id: the run_id of the dagrun to start looking from
:param execution_date: the execution date from which to start looking (deprecated)
:param upstream: Mark all parents (upstream tasks)
@@ -143,7 +145,7 @@ def set_state(
# now look for the task instances that are affected
- qry_dag = get_all_dag_task_query(dag, session, state, task_ids, confirmed_dates)
+ qry_dag = get_all_dag_task_query(dag, session, state, task_ids, confirmed_dates, map_indexes)
if commit:
tis_altered = qry_dag.with_for_update().all()
@@ -181,6 +183,7 @@ def get_all_dag_task_query(
state: TaskInstanceState,
task_ids: List[str],
confirmed_dates: Iterable[datetime],
+ map_indexes: Optional[Iterable[int]] = None,
):
"""Get all tasks of the main dag that will be affected by a state change"""
qry_dag = (
@@ -194,6 +197,8 @@ def get_all_dag_task_query(
.filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
.options(contains_eager(TaskInstance.dag_run))
)
+ if map_indexes:
+ qry_dag = qry_dag.filter(TaskInstance.map_index.in_(map_indexes))
return qry_dag
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 8d5e8eacd6..8efedbef60 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -1348,6 +1348,7 @@ class DAG(LoggingMixin):
Query,
self._get_task_instances(
task_ids=None,
+ map_indexes=None,
start_date=start_date,
end_date=end_date,
run_id=None,
@@ -1368,6 +1369,7 @@ class DAG(LoggingMixin):
self,
*,
task_ids,
+ map_indexes: Optional[Iterable[int]] = None,
start_date: Optional[datetime],
end_date: Optional[datetime],
run_id: Optional[str],
@@ -1386,6 +1388,7 @@ class DAG(LoggingMixin):
self,
*,
task_ids,
+ map_indexes: Optional[Iterable[int]] = None,
as_pk_tuple: Literal[True],
start_date: Optional[datetime],
end_date: Optional[datetime],
@@ -1407,6 +1410,7 @@ class DAG(LoggingMixin):
self,
*,
task_ids,
+ map_indexes: Optional[Iterable[int]] = None,
as_pk_tuple: Literal[True, None] = None,
start_date: Optional[datetime],
end_date: Optional[datetime],
@@ -1434,7 +1438,7 @@ class DAG(LoggingMixin):
# Do we want full objects, or just the primary columns?
if as_pk_tuple:
- tis = session.query(TI.dag_id, TI.task_id, TI.run_id)
+ tis = session.query(TI.dag_id, TI.task_id, TI.run_id, TI.map_index)
else:
tis = session.query(TaskInstance)
tis = tis.join(TaskInstance.dag_run)
@@ -1455,6 +1459,8 @@ class DAG(LoggingMixin):
tis = tis.filter(DagRun.execution_date >= start_date)
if task_ids:
tis = tis.filter(TaskInstance.task_id.in_(task_ids))
+ if map_indexes:
+ tis = tis.filter(TaskInstance.map_index.in_(map_indexes))
# This allows allow_trigger_in_future config to take effect, rather than mandating exec_date <= UTC
if end_date or not self.allow_future_exec_dates:
@@ -1493,6 +1499,7 @@ class DAG(LoggingMixin):
result.update(
p_dag._get_task_instances(
task_ids=task_ids,
+ map_indexes=map_indexes,
start_date=start_date,
end_date=end_date,
run_id=None,
@@ -1570,6 +1577,7 @@ class DAG(LoggingMixin):
result.update(
downstream._get_task_instances(
task_ids=None,
+ map_indexes=None,
run_id=tii.run_id,
start_date=None,
end_date=None,
@@ -1606,7 +1614,7 @@ class DAG(LoggingMixin):
return result
elif result:
# We've been asked for objects, let's combine it all back into a result set
- tis = tis.with_entities(TI.dag_id, TI.task_id, TI.run_id)
+ tis = tis.with_entities(TI.dag_id, TI.task_id, TI.run_id, TI.map_index)
tis = session.query(TI).filter(TI.filter_for_tis(result))
elif exclude_task_ids:
@@ -1619,6 +1627,7 @@ class DAG(LoggingMixin):
self,
*,
task_id: str,
+ map_indexes: Optional[Iterable[int]] = None,
execution_date: Optional[datetime] = None,
run_id: Optional[str] = None,
state: TaskInstanceState,
@@ -1634,6 +1643,7 @@ class DAG(LoggingMixin):
in failed or upstream_failed state.
:param task_id: Task ID of the TaskInstance
+ :param map_indexes: Task instance map_index to set the state of
:param execution_date: Execution date of the TaskInstance
:param run_id: The run_id of the TaskInstance
:param state: State to set the TaskInstance to
@@ -1661,6 +1671,7 @@ class DAG(LoggingMixin):
altered = set_state(
tasks=[task],
+ map_indexes=map_indexes,
execution_date=execution_date,
run_id=run_id,
upstream=upstream,
@@ -1754,6 +1765,7 @@ class DAG(LoggingMixin):
def clear(
self,
task_ids=None,
+ map_indexes: Optional[Iterable[int]] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
only_failed: bool = False,
@@ -1775,6 +1787,7 @@ class DAG(LoggingMixin):
a specified date range.
:param task_ids: List of task ids to clear
+ :param map_indexes: List of map_indexes to clear
:param start_date: The minimum execution_date to clear
:param end_date: The maximum execution_date to clear
:param only_failed: Only clear failed tasks
@@ -1820,6 +1833,7 @@ class DAG(LoggingMixin):
tis = self._get_task_instances(
task_ids=task_ids,
+ map_indexes=map_indexes,
start_date=start_date,
end_date=end_date,
run_id=None,
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 953164e94a..ad9378f462 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1980,6 +1980,7 @@ class Airflow(AirflowBaseView):
start_date,
end_date,
origin,
+ map_indexes=None,
recursive=False,
confirmed=False,
only_failed=False,
@@ -1988,6 +1989,7 @@ class Airflow(AirflowBaseView):
count = dag.clear(
start_date=start_date,
end_date=end_date,
+ map_indexes=map_indexes,
include_subdags=recursive,
include_parentdag=recursive,
only_failed=only_failed,
@@ -2000,6 +2002,7 @@ class Airflow(AirflowBaseView):
tis = dag.clear(
start_date=start_date,
end_date=end_date,
+ map_indexes=map_indexes,
include_subdags=recursive,
include_parentdag=recursive,
only_failed=only_failed,
@@ -2041,6 +2044,9 @@ class Airflow(AirflowBaseView):
task_id = request.form.get('task_id')
origin = get_safe_url(request.form.get('origin'))
dag = current_app.dag_bag.get_dag(dag_id)
+ map_indexes = request.form.get('map_indexes')
+ if map_indexes and not isinstance(map_indexes, list):
+ map_indexes = list(map_indexes)
execution_date = request.form.get('execution_date')
execution_date = timezone.parse(execution_date)
@@ -2065,6 +2071,7 @@ class Airflow(AirflowBaseView):
start_date,
end_date,
origin,
+ map_indexes=map_indexes,
recursive=recursive,
confirmed=confirmed,
only_failed=only_failed,
@@ -2083,6 +2090,9 @@ class Airflow(AirflowBaseView):
dag_id = request.form.get('dag_id')
dag_run_id = request.form.get('dag_run_id')
confirmed = request.form.get('confirmed') == "true"
+ map_indexes = request.form.get('map_indexes')
+ if map_indexes and not isinstance(map_indexes, list):
+ map_indexes = list(map_indexes)
dag = current_app.dag_bag.get_dag(dag_id)
dr = dag.get_dagrun(run_id=dag_run_id)
@@ -2093,6 +2103,7 @@ class Airflow(AirflowBaseView):
dag,
start_date,
end_date,
+ map_indexes=map_indexes,
origin=None,
recursive=True,
confirmed=confirmed,
@@ -2299,6 +2310,7 @@ class Airflow(AirflowBaseView):
self,
dag_id,
task_id,
+ map_indexes,
origin,
dag_run_id,
upstream,
@@ -2316,6 +2328,7 @@ class Airflow(AirflowBaseView):
altered = dag.set_task_instance_state(
task_id=task_id,
+ map_index=map_indexes,
run_id=dag_run_id,
state=state,
upstream=upstream,
@@ -2418,6 +2431,9 @@ class Airflow(AirflowBaseView):
task_id = args.get('task_id')
origin = get_safe_url(args.get('origin'))
dag_run_id = args.get('dag_run_id')
+ map_indexes = args.get('map_indexes')
+ if map_indexes and not isinstance(map_indexes, list):
+ map_indexes = list(map_indexes)
upstream = to_boolean(args.get('upstream'))
downstream = to_boolean(args.get('downstream'))
@@ -2427,6 +2443,7 @@ class Airflow(AirflowBaseView):
return self._mark_task_instance_state(
dag_id,
task_id,
+ map_indexes,
origin,
dag_run_id,
upstream,
@@ -2451,6 +2468,9 @@ class Airflow(AirflowBaseView):
task_id = args.get('task_id')
origin = get_safe_url(args.get('origin'))
dag_run_id = args.get('dag_run_id')
+ map_indexes = args.get('map_indexes')
+ if map_indexes and not isinstance(map_indexes, list):
+ map_indexes = list(map_indexes)
upstream = to_boolean(args.get('upstream'))
downstream = to_boolean(args.get('downstream'))
@@ -2460,6 +2480,7 @@ class Airflow(AirflowBaseView):
return self._mark_task_instance_state(
dag_id,
task_id,
+ map_indexes,
origin,
dag_run_id,
upstream,