You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by bb...@apache.org on 2022/04/20 14:31:32 UTC
[airflow] 01/09: Accept multiple map_index params from front end
This is an automated email from the ASF dual-hosted git repository.
bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 4ffae5cf1fa8d460451c875c0940c57d91187b43
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Tue Apr 19 09:16:04 2022 +0800
Accept multiple map_index params from front end
This allows setting multiple instances of the same task to SUCCESS or
FAILED in one request. This is translated to multiple task specifier
tuples (task_id, map_index) when passed to set_state().
Also made some drive-by improvements, adding types and cleaning up some
formatting.
---
airflow/api/common/mark_tasks.py | 4 +-
airflow/models/dag.py | 12 ++---
airflow/www/views.py | 105 ++++++++++++++++++++----------------
tests/www/views/test_views.py | 3 +-
tests/www/views/test_views_tasks.py | 2 +-
5 files changed, 71 insertions(+), 55 deletions(-)
diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index 594423305c..349b935e82 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -18,7 +18,7 @@
"""Marks tasks APIs."""
from datetime import datetime
-from typing import TYPE_CHECKING, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Collection, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union
from sqlalchemy import or_, tuple_
from sqlalchemy.orm import contains_eager
@@ -78,7 +78,7 @@ def _create_dagruns(
@provide_session
def set_state(
*,
- tasks: Union[Iterable[Operator], Iterable[Tuple[Operator, int]]],
+ tasks: Union[Collection[Operator], Collection[Tuple[Operator, int]]],
run_id: Optional[str] = None,
execution_date: Optional[datetime] = None,
upstream: bool = False,
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 755505b5d0..9c93bcef13 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -1620,7 +1620,7 @@ class DAG(LoggingMixin):
self,
*,
task_id: str,
- map_index: Optional[int] = None,
+ map_indexes: Optional[Collection[int]] = None,
execution_date: Optional[datetime] = None,
run_id: Optional[str] = None,
state: TaskInstanceState,
@@ -1636,8 +1636,8 @@ class DAG(LoggingMixin):
in failed or upstream_failed state.
:param task_id: Task ID of the TaskInstance
- :param map_index: The TaskInstance map_index, if None, would set state for all mapped
- TaskInstances of the task
+ :param map_indexes: Only set TaskInstance if its map_index matches.
+ If None (default), all mapped TaskInstances of the task are set.
:param execution_date: Execution date of the TaskInstance
:param run_id: The run_id of the TaskInstance
:param state: State to set the TaskInstance to
@@ -1665,12 +1665,12 @@ class DAG(LoggingMixin):
tasks_to_set_state: Union[List[Operator], List[Tuple[Operator, int]]]
task_ids_to_exclude_from_clear: Union[Set[str], Set[Tuple[str, int]]]
- if map_index is None:
+ if map_indexes is None:
tasks_to_set_state = [task]
task_ids_to_exclude_from_clear = {task_id}
else:
- tasks_to_set_state = [(task, map_index)]
- task_ids_to_exclude_from_clear = {(task_id, map_index)}
+ tasks_to_set_state = [(task, map_index) for map_index in map_indexes]
+ task_ids_to_exclude_from_clear = {(task_id, map_index) for map_index in map_indexes}
altered = set_state(
tasks=tasks_to_set_state,
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 437a60cca0..aac30f64ff 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -95,6 +95,7 @@ from airflow.api.common.mark_tasks import (
set_dag_run_state_to_failed,
set_dag_run_state_to_queued,
set_dag_run_state_to_success,
+ set_state,
)
from airflow.compat.functools import cached_property
from airflow.configuration import AIRFLOW_CONFIG, conf
@@ -107,6 +108,7 @@ from airflow.models import DAG, Connection, DagModel, DagTag, Log, SlaMiss, Task
from airflow.models.abstractoperator import AbstractOperator
from airflow.models.dagcode import DagCode
from airflow.models.dagrun import DagRun, DagRunType
+from airflow.models.operator import Operator
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstance
from airflow.providers_manager import ProvidersManager
@@ -2284,28 +2286,28 @@ class Airflow(AirflowBaseView):
def _mark_task_instance_state(
self,
- dag_id,
- task_id,
- origin,
- dag_run_id,
- upstream,
- downstream,
- future,
- past,
- state,
- map_index=None,
+ *,
+ dag_id: str,
+ run_id: str,
+ task_id: str,
+ map_indexes: Optional[List[int]],
+ origin: str,
+ upstream: bool,
+ downstream: bool,
+ future: bool,
+ past: bool,
+ state: TaskInstanceState,
):
dag = current_app.dag_bag.get_dag(dag_id)
- latest_execution_date = dag.get_latest_execution_date()
- if not latest_execution_date:
- flash(f"Cannot mark tasks as {state}, seem that dag {dag_id} has never run", "error")
+ if not run_id:
+ flash(f"Cannot mark tasks as {state}, seem that DAG {dag_id} has never run", "error")
return redirect(origin)
altered = dag.set_task_instance_state(
task_id=task_id,
- map_index=map_index,
- run_id=dag_run_id,
+ map_indexes=map_indexes,
+ run_id=run_id,
state=state,
upstream=upstream,
downstream=downstream,
@@ -2332,7 +2334,11 @@ class Airflow(AirflowBaseView):
dag_run_id = args.get('dag_run_id')
state = args.get('state')
origin = args.get('origin')
- map_index = args.get('map_index')
+
+ if 'map_index' not in args:
+ map_indexes: Optional[List[int]] = None
+ else:
+ map_indexes = args.getlist('map_index', type=int)
upstream = to_boolean(args.get('upstream'))
downstream = to_boolean(args.get('downstream'))
@@ -2365,9 +2371,10 @@ class Airflow(AirflowBaseView):
msg = f"Cannot mark tasks as {state}, seem that dag {dag_id} has never run"
return redirect_or_json(origin, msg, status='error')
- from airflow.api.common.mark_tasks import set_state
-
- tasks = [(task, map_index)] if map_index else [task]
+ if map_indexes is None:
+ tasks: Union[List[Operator], List[Tuple[Operator, int]]] = [task]
+ else:
+ tasks = [(task, map_index) for map_index in map_indexes]
to_be_altered = set_state(
tasks=tasks,
@@ -2408,26 +2415,30 @@ class Airflow(AirflowBaseView):
args = request.form
dag_id = args.get('dag_id')
task_id = args.get('task_id')
- origin = get_safe_url(args.get('origin'))
- dag_run_id = args.get('dag_run_id')
- map_index = args.get('map_index')
+ run_id = args.get('dag_run_id')
+ if 'map_index' not in args:
+ map_indexes: Optional[List[int]] = None
+ else:
+ map_indexes = args.getlist('map_index', type=int)
+
+ origin = get_safe_url(args.get('origin'))
upstream = to_boolean(args.get('upstream'))
downstream = to_boolean(args.get('downstream'))
future = to_boolean(args.get('future'))
past = to_boolean(args.get('past'))
return self._mark_task_instance_state(
- dag_id,
- task_id,
- origin,
- dag_run_id,
- upstream,
- downstream,
- future,
- past,
- State.FAILED,
- map_index=map_index,
+ dag_id=dag_id,
+ run_id=run_id,
+ task_id=task_id,
+ map_indexes=map_indexes,
+ origin=origin,
+ upstream=upstream,
+ downstream=downstream,
+ future=future,
+ past=past,
+ state=TaskInstanceState.FAILED,
)
@expose('/success', methods=['POST'])
@@ -2443,26 +2454,30 @@ class Airflow(AirflowBaseView):
args = request.form
dag_id = args.get('dag_id')
task_id = args.get('task_id')
- origin = get_safe_url(args.get('origin'))
- dag_run_id = args.get('dag_run_id')
- map_index = args.get('map_index')
+ run_id = args.get('dag_run_id')
+
+ if 'map_index' not in args:
+ map_indexes: Optional[List[int]] = None
+ else:
+ map_indexes = args.getlist('map_index', type=int)
+ origin = get_safe_url(args.get('origin'))
upstream = to_boolean(args.get('upstream'))
downstream = to_boolean(args.get('downstream'))
future = to_boolean(args.get('future'))
past = to_boolean(args.get('past'))
return self._mark_task_instance_state(
- dag_id,
- task_id,
- origin,
- dag_run_id,
- upstream,
- downstream,
- future,
- past,
- State.SUCCESS,
- map_index=map_index,
+ dag_id=dag_id,
+ run_id=run_id,
+ task_id=task_id,
+ map_indexes=map_indexes,
+ origin=origin,
+ upstream=upstream,
+ downstream=downstream,
+ future=future,
+ past=past,
+ state=TaskInstanceState.SUCCESS,
)
@expose('/dags/<string:dag_id>')
diff --git a/tests/www/views/test_views.py b/tests/www/views/test_views.py
index f4be0540c3..c7900d64fd 100644
--- a/tests/www/views/test_views.py
+++ b/tests/www/views/test_views.py
@@ -271,9 +271,10 @@ def test_mark_task_instance_state(test_app):
view._mark_task_instance_state(
dag_id=dag.dag_id,
+ run_id=dagrun.run_id,
task_id=task_1.task_id,
+ map_indexes=None,
origin="",
- dag_run_id=dagrun.run_id,
upstream=False,
downstream=False,
future=False,
diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py
index fce94fd5e4..ebed9ab05f 100644
--- a/tests/www/views/test_views_tasks.py
+++ b/tests/www/views/test_views_tasks.py
@@ -517,7 +517,7 @@ def test_dag_never_run(admin_client, url):
)
clear_db_runs()
resp = admin_client.post(url, data=form, follow_redirects=True)
- check_content_in_response(f"Cannot mark tasks as {url}, seem that dag {dag_id} has never run", resp)
+ check_content_in_response(f"Cannot mark tasks as {url}, seem that DAG {dag_id} has never run", resp)
class _ForceHeartbeatCeleryExecutor(CeleryExecutor):