You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by bb...@apache.org on 2022/04/20 14:31:32 UTC

[airflow] 01/09: Accept multiple map_index param from front end

This is an automated email from the ASF dual-hosted git repository.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 4ffae5cf1fa8d460451c875c0940c57d91187b43
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Tue Apr 19 09:16:04 2022 +0800

    Accept multiple map_index param from front end
    
    This allows setting multiple instances of the same task to SUCCESS or
    FAILED in one request. This is translated to multiple task specifier
    tuples (task_id, map_index) when passed to set_state().
    
    Also made some drive-by improvements, adding types and cleaning up
    some formatting.
---
 airflow/api/common/mark_tasks.py    |   4 +-
 airflow/models/dag.py               |  12 ++---
 airflow/www/views.py                | 105 ++++++++++++++++++++----------------
 tests/www/views/test_views.py       |   3 +-
 tests/www/views/test_views_tasks.py |   2 +-
 5 files changed, 71 insertions(+), 55 deletions(-)

diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index 594423305c..349b935e82 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -18,7 +18,7 @@
 """Marks tasks APIs."""
 
 from datetime import datetime
-from typing import TYPE_CHECKING, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Collection, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union
 
 from sqlalchemy import or_, tuple_
 from sqlalchemy.orm import contains_eager
@@ -78,7 +78,7 @@ def _create_dagruns(
 @provide_session
 def set_state(
     *,
-    tasks: Union[Iterable[Operator], Iterable[Tuple[Operator, int]]],
+    tasks: Union[Collection[Operator], Collection[Tuple[Operator, int]]],
     run_id: Optional[str] = None,
     execution_date: Optional[datetime] = None,
     upstream: bool = False,
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 755505b5d0..9c93bcef13 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -1620,7 +1620,7 @@ class DAG(LoggingMixin):
         self,
         *,
         task_id: str,
-        map_index: Optional[int] = None,
+        map_indexes: Optional[Collection[int]] = None,
         execution_date: Optional[datetime] = None,
         run_id: Optional[str] = None,
         state: TaskInstanceState,
@@ -1636,8 +1636,8 @@ class DAG(LoggingMixin):
         in failed or upstream_failed state.
 
         :param task_id: Task ID of the TaskInstance
-        :param map_index: The TaskInstance map_index, if None, would set state for all mapped
-            TaskInstances of the task
+        :param map_indexes: Only set TaskInstance if its map_index matches.
+            If None (default), all mapped TaskInstances of the task are set.
         :param execution_date: Execution date of the TaskInstance
         :param run_id: The run_id of the TaskInstance
         :param state: State to set the TaskInstance to
@@ -1665,12 +1665,12 @@ class DAG(LoggingMixin):
 
         tasks_to_set_state: Union[List[Operator], List[Tuple[Operator, int]]]
         task_ids_to_exclude_from_clear: Union[Set[str], Set[Tuple[str, int]]]
-        if map_index is None:
+        if map_indexes is None:
             tasks_to_set_state = [task]
             task_ids_to_exclude_from_clear = {task_id}
         else:
-            tasks_to_set_state = [(task, map_index)]
-            task_ids_to_exclude_from_clear = {(task_id, map_index)}
+            tasks_to_set_state = [(task, map_index) for map_index in map_indexes]
+            task_ids_to_exclude_from_clear = {(task_id, map_index) for map_index in map_indexes}
 
         altered = set_state(
             tasks=tasks_to_set_state,
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 437a60cca0..aac30f64ff 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -95,6 +95,7 @@ from airflow.api.common.mark_tasks import (
     set_dag_run_state_to_failed,
     set_dag_run_state_to_queued,
     set_dag_run_state_to_success,
+    set_state,
 )
 from airflow.compat.functools import cached_property
 from airflow.configuration import AIRFLOW_CONFIG, conf
@@ -107,6 +108,7 @@ from airflow.models import DAG, Connection, DagModel, DagTag, Log, SlaMiss, Task
 from airflow.models.abstractoperator import AbstractOperator
 from airflow.models.dagcode import DagCode
 from airflow.models.dagrun import DagRun, DagRunType
+from airflow.models.operator import Operator
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskinstance import TaskInstance
 from airflow.providers_manager import ProvidersManager
@@ -2284,28 +2286,28 @@ class Airflow(AirflowBaseView):
 
     def _mark_task_instance_state(
         self,
-        dag_id,
-        task_id,
-        origin,
-        dag_run_id,
-        upstream,
-        downstream,
-        future,
-        past,
-        state,
-        map_index=None,
+        *,
+        dag_id: str,
+        run_id: str,
+        task_id: str,
+        map_indexes: Optional[List[int]],
+        origin: str,
+        upstream: bool,
+        downstream: bool,
+        future: bool,
+        past: bool,
+        state: TaskInstanceState,
     ):
         dag = current_app.dag_bag.get_dag(dag_id)
-        latest_execution_date = dag.get_latest_execution_date()
 
-        if not latest_execution_date:
-            flash(f"Cannot mark tasks as {state}, seem that dag {dag_id} has never run", "error")
+        if not run_id:
+            flash(f"Cannot mark tasks as {state}, seem that DAG {dag_id} has never run", "error")
             return redirect(origin)
 
         altered = dag.set_task_instance_state(
             task_id=task_id,
-            map_index=map_index,
-            run_id=dag_run_id,
+            map_indexes=map_indexes,
+            run_id=run_id,
             state=state,
             upstream=upstream,
             downstream=downstream,
@@ -2332,7 +2334,11 @@ class Airflow(AirflowBaseView):
         dag_run_id = args.get('dag_run_id')
         state = args.get('state')
         origin = args.get('origin')
-        map_index = args.get('map_index')
+
+        if 'map_index' not in args:
+            map_indexes: Optional[List[int]] = None
+        else:
+            map_indexes = args.getlist('map_index', type=int)
 
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
@@ -2365,9 +2371,10 @@ class Airflow(AirflowBaseView):
             msg = f"Cannot mark tasks as {state}, seem that dag {dag_id} has never run"
             return redirect_or_json(origin, msg, status='error')
 
-        from airflow.api.common.mark_tasks import set_state
-
-        tasks = [(task, map_index)] if map_index else [task]
+        if map_indexes is None:
+            tasks: Union[List[Operator], List[Tuple[Operator, int]]] = [task]
+        else:
+            tasks = [(task, map_index) for map_index in map_indexes]
 
         to_be_altered = set_state(
             tasks=tasks,
@@ -2408,26 +2415,30 @@ class Airflow(AirflowBaseView):
         args = request.form
         dag_id = args.get('dag_id')
         task_id = args.get('task_id')
-        origin = get_safe_url(args.get('origin'))
-        dag_run_id = args.get('dag_run_id')
-        map_index = args.get('map_index')
+        run_id = args.get('dag_run_id')
 
+        if 'map_index' not in args:
+            map_indexes: Optional[List[int]] = None
+        else:
+            map_indexes = args.getlist('map_index', type=int)
+
+        origin = get_safe_url(args.get('origin'))
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
         future = to_boolean(args.get('future'))
         past = to_boolean(args.get('past'))
 
         return self._mark_task_instance_state(
-            dag_id,
-            task_id,
-            origin,
-            dag_run_id,
-            upstream,
-            downstream,
-            future,
-            past,
-            State.FAILED,
-            map_index=map_index,
+            dag_id=dag_id,
+            run_id=run_id,
+            task_id=task_id,
+            map_indexes=map_indexes,
+            origin=origin,
+            upstream=upstream,
+            downstream=downstream,
+            future=future,
+            past=past,
+            state=TaskInstanceState.FAILED,
         )
 
     @expose('/success', methods=['POST'])
@@ -2443,26 +2454,30 @@ class Airflow(AirflowBaseView):
         args = request.form
         dag_id = args.get('dag_id')
         task_id = args.get('task_id')
-        origin = get_safe_url(args.get('origin'))
-        dag_run_id = args.get('dag_run_id')
-        map_index = args.get('map_index')
+        run_id = args.get('dag_run_id')
+
+        if 'map_index' not in args:
+            map_indexes: Optional[List[int]] = None
+        else:
+            map_indexes = args.getlist('map_index', type=int)
 
+        origin = get_safe_url(args.get('origin'))
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
         future = to_boolean(args.get('future'))
         past = to_boolean(args.get('past'))
 
         return self._mark_task_instance_state(
-            dag_id,
-            task_id,
-            origin,
-            dag_run_id,
-            upstream,
-            downstream,
-            future,
-            past,
-            State.SUCCESS,
-            map_index=map_index,
+            dag_id=dag_id,
+            run_id=run_id,
+            task_id=task_id,
+            map_indexes=map_indexes,
+            origin=origin,
+            upstream=upstream,
+            downstream=downstream,
+            future=future,
+            past=past,
+            state=TaskInstanceState.SUCCESS,
         )
 
     @expose('/dags/<string:dag_id>')
diff --git a/tests/www/views/test_views.py b/tests/www/views/test_views.py
index f4be0540c3..c7900d64fd 100644
--- a/tests/www/views/test_views.py
+++ b/tests/www/views/test_views.py
@@ -271,9 +271,10 @@ def test_mark_task_instance_state(test_app):
 
         view._mark_task_instance_state(
             dag_id=dag.dag_id,
+            run_id=dagrun.run_id,
             task_id=task_1.task_id,
+            map_indexes=None,
             origin="",
-            dag_run_id=dagrun.run_id,
             upstream=False,
             downstream=False,
             future=False,
diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py
index fce94fd5e4..ebed9ab05f 100644
--- a/tests/www/views/test_views_tasks.py
+++ b/tests/www/views/test_views_tasks.py
@@ -517,7 +517,7 @@ def test_dag_never_run(admin_client, url):
     )
     clear_db_runs()
     resp = admin_client.post(url, data=form, follow_redirects=True)
-    check_content_in_response(f"Cannot mark tasks as {url}, seem that dag {dag_id} has never run", resp)
+    check_content_in_response(f"Cannot mark tasks as {url}, seem that DAG {dag_id} has never run", resp)
 
 
 class _ForceHeartbeatCeleryExecutor(CeleryExecutor):