You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by bb...@apache.org on 2022/04/12 18:56:03 UTC

[airflow] 01/06: Allow marking/clearing mapped task instances from the UI

This is an automated email from the ASF dual-hosted git repository.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 7a78a3efdb6820bcf4b3ee3fef27d8c1e4454b5d
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Tue Apr 12 16:37:35 2022 +0100

    Allow marking/clearing mapped task instances from the UI
---
 airflow/api/common/mark_tasks.py |  7 ++++++-
 airflow/models/dag.py            | 18 ++++++++++++++++--
 airflow/www/views.py             | 21 +++++++++++++++++++++
 3 files changed, 43 insertions(+), 3 deletions(-)

diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index d11f490247..fe9fa0f490 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -79,6 +79,7 @@ def _create_dagruns(
 def set_state(
     *,
     tasks: Iterable[Operator],
+    map_indexes: Optional[Iterable[int]] = None,
     run_id: Optional[str] = None,
     execution_date: Optional[datetime] = None,
     upstream: bool = False,
@@ -97,6 +98,7 @@ def set_state(
     on the schedule (but it will as for subdag dag runs if needed).
 
     :param tasks: the iterable of tasks from which to work. task.task.dag needs to be set
+    :param map_indexes: the map indexes of the tasks to set
     :param run_id: the run_id of the dagrun to start looking from
     :param execution_date: the execution date from which to start looking(deprecated)
     :param upstream: Mark all parents (upstream tasks)
@@ -143,7 +145,7 @@ def set_state(
 
     # now look for the task instances that are affected
 
-    qry_dag = get_all_dag_task_query(dag, session, state, task_ids, confirmed_dates)
+    qry_dag = get_all_dag_task_query(dag, session, state, task_ids, confirmed_dates, map_indexes)
 
     if commit:
         tis_altered = qry_dag.with_for_update().all()
@@ -181,6 +183,7 @@ def get_all_dag_task_query(
     state: TaskInstanceState,
     task_ids: List[str],
     confirmed_dates: Iterable[datetime],
+    map_indexes: Optional[Iterable[int]] = None,
 ):
     """Get all tasks of the main dag that will be affected by a state change"""
     qry_dag = (
@@ -194,6 +197,8 @@ def get_all_dag_task_query(
         .filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
         .options(contains_eager(TaskInstance.dag_run))
     )
+    if map_indexes:
+        qry_dag = qry_dag.filter(TaskInstance.map_index.in_(map_indexes))
     return qry_dag
 
 
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 8d5e8eacd6..8efedbef60 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -1348,6 +1348,7 @@ class DAG(LoggingMixin):
                 Query,
                 self._get_task_instances(
                     task_ids=None,
+                    map_indexes=None,
                     start_date=start_date,
                     end_date=end_date,
                     run_id=None,
@@ -1368,6 +1369,7 @@ class DAG(LoggingMixin):
         self,
         *,
         task_ids,
+        map_indexes: Optional[Iterable[int]] = None,
         start_date: Optional[datetime],
         end_date: Optional[datetime],
         run_id: Optional[str],
@@ -1386,6 +1388,7 @@ class DAG(LoggingMixin):
         self,
         *,
         task_ids,
+        map_indexes: Optional[Iterable[int]] = None,
         as_pk_tuple: Literal[True],
         start_date: Optional[datetime],
         end_date: Optional[datetime],
@@ -1407,6 +1410,7 @@ class DAG(LoggingMixin):
         self,
         *,
         task_ids,
+        map_indexes: Optional[Iterable[int]] = None,
         as_pk_tuple: Literal[True, None] = None,
         start_date: Optional[datetime],
         end_date: Optional[datetime],
@@ -1434,7 +1438,7 @@ class DAG(LoggingMixin):
 
         # Do we want full objects, or just the primary columns?
         if as_pk_tuple:
-            tis = session.query(TI.dag_id, TI.task_id, TI.run_id)
+            tis = session.query(TI.dag_id, TI.task_id, TI.run_id, TI.map_index)
         else:
             tis = session.query(TaskInstance)
         tis = tis.join(TaskInstance.dag_run)
@@ -1455,6 +1459,8 @@ class DAG(LoggingMixin):
             tis = tis.filter(DagRun.execution_date >= start_date)
         if task_ids:
             tis = tis.filter(TaskInstance.task_id.in_(task_ids))
+        if map_indexes:
+            tis = tis.filter(TaskInstance.map_index.in_(map_indexes))
 
         # This allows allow_trigger_in_future config to take affect, rather than mandating exec_date <= UTC
         if end_date or not self.allow_future_exec_dates:
@@ -1493,6 +1499,7 @@ class DAG(LoggingMixin):
             result.update(
                 p_dag._get_task_instances(
                     task_ids=task_ids,
+                    map_indexes=map_indexes,
                     start_date=start_date,
                     end_date=end_date,
                     run_id=None,
@@ -1570,6 +1577,7 @@ class DAG(LoggingMixin):
                     result.update(
                         downstream._get_task_instances(
                             task_ids=None,
+                            map_indexes=None,
                             run_id=tii.run_id,
                             start_date=None,
                             end_date=None,
@@ -1606,7 +1614,7 @@ class DAG(LoggingMixin):
             return result
         elif result:
             # We've been asked for objects, lets combine it all back in to a result set
-            tis = tis.with_entities(TI.dag_id, TI.task_id, TI.run_id)
+            tis = tis.with_entities(TI.dag_id, TI.task_id, TI.run_id, TI.map_index)
 
             tis = session.query(TI).filter(TI.filter_for_tis(result))
         elif exclude_task_ids:
@@ -1619,6 +1627,7 @@ class DAG(LoggingMixin):
         self,
         *,
         task_id: str,
+        map_indexes: Optional[Iterable[int]] = None,
         execution_date: Optional[datetime] = None,
         run_id: Optional[str] = None,
         state: TaskInstanceState,
@@ -1634,6 +1643,7 @@ class DAG(LoggingMixin):
         in failed or upstream_failed state.
 
         :param task_id: Task ID of the TaskInstance
+        :param map_indexes: Task instance map_index to set the state of
         :param execution_date: Execution date of the TaskInstance
         :param run_id: The run_id of the TaskInstance
         :param state: State to set the TaskInstance to
@@ -1661,6 +1671,7 @@ class DAG(LoggingMixin):
 
         altered = set_state(
             tasks=[task],
+            map_indexes=map_indexes,
             execution_date=execution_date,
             run_id=run_id,
             upstream=upstream,
@@ -1754,6 +1765,7 @@ class DAG(LoggingMixin):
     def clear(
         self,
         task_ids=None,
+        map_indexes: Optional[Iterable[int]] = None,
         start_date: Optional[datetime] = None,
         end_date: Optional[datetime] = None,
         only_failed: bool = False,
@@ -1775,6 +1787,7 @@ class DAG(LoggingMixin):
         a specified date range.
 
         :param task_ids: List of task ids to clear
+        :param map_indexes: List of map_indexes to clear
         :param start_date: The minimum execution_date to clear
         :param end_date: The maximum execution_date to clear
         :param only_failed: Only clear failed tasks
@@ -1820,6 +1833,7 @@ class DAG(LoggingMixin):
 
         tis = self._get_task_instances(
             task_ids=task_ids,
+            map_indexes=map_indexes,
             start_date=start_date,
             end_date=end_date,
             run_id=None,
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 953164e94a..ad9378f462 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1980,6 +1980,7 @@ class Airflow(AirflowBaseView):
         start_date,
         end_date,
         origin,
+        map_indexes=None,
         recursive=False,
         confirmed=False,
         only_failed=False,
@@ -1988,6 +1989,7 @@ class Airflow(AirflowBaseView):
             count = dag.clear(
                 start_date=start_date,
                 end_date=end_date,
+                map_indexes=map_indexes,
                 include_subdags=recursive,
                 include_parentdag=recursive,
                 only_failed=only_failed,
@@ -2000,6 +2002,7 @@ class Airflow(AirflowBaseView):
             tis = dag.clear(
                 start_date=start_date,
                 end_date=end_date,
+                map_indexes=map_indexes,
                 include_subdags=recursive,
                 include_parentdag=recursive,
                 only_failed=only_failed,
@@ -2041,6 +2044,9 @@ class Airflow(AirflowBaseView):
         task_id = request.form.get('task_id')
         origin = get_safe_url(request.form.get('origin'))
         dag = current_app.dag_bag.get_dag(dag_id)
+        map_indexes = request.form.get('map_indexes')
+        if map_indexes and not isinstance(map_indexes, list):
+            map_indexes = list(map_indexes)
 
         execution_date = request.form.get('execution_date')
         execution_date = timezone.parse(execution_date)
@@ -2065,6 +2071,7 @@ class Airflow(AirflowBaseView):
             start_date,
             end_date,
             origin,
+            map_indexes=map_indexes,
             recursive=recursive,
             confirmed=confirmed,
             only_failed=only_failed,
@@ -2083,6 +2090,9 @@ class Airflow(AirflowBaseView):
         dag_id = request.form.get('dag_id')
         dag_run_id = request.form.get('dag_run_id')
         confirmed = request.form.get('confirmed') == "true"
+        map_indexes = request.form.get('map_indexes')
+        if map_indexes and not isinstance(map_indexes, list):
+            map_indexes = list(map_indexes)
 
         dag = current_app.dag_bag.get_dag(dag_id)
         dr = dag.get_dagrun(run_id=dag_run_id)
@@ -2093,6 +2103,7 @@ class Airflow(AirflowBaseView):
             dag,
             start_date,
             end_date,
+            map_indexes=map_indexes,
             origin=None,
             recursive=True,
             confirmed=confirmed,
@@ -2299,6 +2310,7 @@ class Airflow(AirflowBaseView):
         self,
         dag_id,
         task_id,
+        map_indexes,
         origin,
         dag_run_id,
         upstream,
@@ -2316,6 +2328,7 @@ class Airflow(AirflowBaseView):
 
         altered = dag.set_task_instance_state(
             task_id=task_id,
+            map_index=map_indexes,
             run_id=dag_run_id,
             state=state,
             upstream=upstream,
@@ -2418,6 +2431,9 @@ class Airflow(AirflowBaseView):
         task_id = args.get('task_id')
         origin = get_safe_url(args.get('origin'))
         dag_run_id = args.get('dag_run_id')
+        map_indexes = args.get('map_indexes')
+        if map_indexes and not isinstance(map_indexes, list):
+            map_indexes = list(map_indexes)
 
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
@@ -2427,6 +2443,7 @@ class Airflow(AirflowBaseView):
         return self._mark_task_instance_state(
             dag_id,
             task_id,
+            map_indexes,
             origin,
             dag_run_id,
             upstream,
@@ -2451,6 +2468,9 @@ class Airflow(AirflowBaseView):
         task_id = args.get('task_id')
         origin = get_safe_url(args.get('origin'))
         dag_run_id = args.get('dag_run_id')
+        map_indexes = args.get('map_indexes')
+        if map_indexes and not isinstance(map_indexes, list):
+            map_indexes = list(map_indexes)
 
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
@@ -2460,6 +2480,7 @@ class Airflow(AirflowBaseView):
         return self._mark_task_instance_state(
             dag_id,
             task_id,
+            map_indexes,
             origin,
             dag_run_id,
             upstream,