Posted to commits@airflow.apache.org by bb...@apache.org on 2022/04/20 19:00:42 UTC

[airflow] branch mapped-instance-actions updated (009eb4d4fe -> 1eb9b5cbdc)

This is an automated email from the ASF dual-hosted git repository.

bbovenzi pushed a change to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git


 discard 009eb4d4fe Get gantt/graph modal actions working again
 discard 7fcedbd7c7 Chain map_index params
 discard 4410c01192 Fix gantt/graph modal
 discard 4ffa59eaf4 Readd mapped instance table selection
 discard c633ca8895 Allow bulk mapped task actions
 discard a890173f5a fixup! Introduce tuple_().in_() shim for MSSQL compat
 discard 6c7f1dfb16 fixup! Accept multiple map_index param from front end
 discard 898480765f Introduce tuple_().in_() shim for MSSQL compat
 discard 4ffae5cf1f Accept multiple map_index param from front end
 discard e1cdc961b3 Refactor to straighten up types
 discard 1c0741234e fixup! Apply suggestions from code review
 discard 1ccc3f6131 Apply suggestions from code review
 discard e65e5fcf7f fixup! fixup! fixup! fixup! fixup! fixup! Allow marking/clearing mapped taskinstances from the UI
 discard 4186bb220d fixup! fixup! fixup! fixup! fixup! Allow marking/clearing mapped taskinstances from the UI
 discard df4cfb4524 fixup! fixup! fixup! fixup! Allow marking/clearing mapped taskinstances from the UI
 discard 6ce2589991 add tests
 discard 59b1769171 fixup! fixup! fixup! Allow marking/clearing mapped taskinstances from the UI
 discard 6f16553973 fixup! fixup! Allow marking/clearing mapped taskinstances from the UI
 discard 7837df0994 fixup! Allow marking/clearing mapped taskinstances from the UI
 discard 1f03977735 Allow marking/clearing mapped taskinstances from the UI
     add 9e1ac6e425 Add `S3CreateObjectOperator` (#22758)
     add 501a3c3fbe Meaningful error mssage in resolve_template_files (#23027)
     add 70eede5dd6 Fix KPO to have hyphen instead of period (#22982)
     add 5b9bd9954b Replace changelog/updating with release notes and towncrier now (#22003)
     add 99cac42df0 Make copy button blue (#23120)
     add eb26510d3a Switch bitnami images in tests to "standard" ones (#23122)
     add 4fa718e4db Support clearing and updating state of individual mapped task instances (#22958)
     new 5620df9469 fixup! Allow marking/clearing mapped taskinstances from the UI
     new 334e39923e fixup! fixup! Allow marking/clearing mapped taskinstances from the UI
     new 83413da968 fixup! fixup! fixup! Allow marking/clearing mapped taskinstances from the UI
     new bcf186d9eb add tests
     new c5cc48f9f9 fixup! fixup! fixup! fixup! Allow marking/clearing mapped taskinstances from the UI
     new 7ea4ace57d fixup! fixup! fixup! fixup! fixup! Allow marking/clearing mapped taskinstances from the UI
     new 5bbcb9f17f fixup! fixup! fixup! fixup! fixup! fixup! Allow marking/clearing mapped taskinstances from the UI
     new 9d5d8b28cb Apply suggestions from code review
     new 4e81a78030 fixup! Apply suggestions from code review
     new 3b12a914db Refactor to straighten up types
     new 03676ff070 Accept multiple map_index param from front end
     new 469092494d Introduce tuple_().in_() shim for MSSQL compat
     new 3327e0dc10 fixup! Accept multiple map_index param from front end
     new 3d25f2ebba fixup! Introduce tuple_().in_() shim for MSSQL compat
     new b77a920335 Allow bulk mapped task actions
     new acb534127b Readd mapped instance table selection
     new 6b7b8026b1 Fix gantt/graph modal
     new f489b107ec Chain map_index params
     new 1eb9b5cbdc Get gantt/graph modal actions working again

This update added new revisions after undoing existing revisions.
That is to say, some revisions that were in the old version of the
branch are not in the new version.  This situation occurs
when a user --force pushes a change and generates a repository
containing something like this:

 * -- * -- B -- O -- O -- O   (009eb4d4fe)
            \
             N -- N -- N   refs/heads/mapped-instance-actions (1eb9b5cbdc)

You should already have received notification emails for all of the O
revisions, and so the following emails describe only the N revisions
from the common base, B.

Any revisions marked "omit" are not gone; other references still
refer to them.  Any revisions marked "discard" are gone forever.

The 19 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.
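
As a rough illustration of the categories described above, the following Python sketch classifies commits after a force push by comparing what is reachable from the old and new branch tips. It is a simplification under stated assumptions, not the notification script's actual logic: the commit graph is a plain dict of hypothetical commit names, and `other_tips` stands in for every other reference in the repository.

```python
from typing import Dict, Iterable, List, Set


def reachable(parents: Dict[str, List[str]], tip: str) -> Set[str]:
    """Return every commit reachable from `tip` by following parent links."""
    seen: Set[str] = set()
    stack = [tip]
    while stack:
        commit = stack.pop()
        if commit in seen:
            continue
        seen.add(commit)
        stack.extend(parents.get(commit, ()))
    return seen


def classify_force_push(
    parents: Dict[str, List[str]],
    old_tip: str,
    new_tip: str,
    other_tips: Iterable[str] = (),
) -> Dict[str, Set[str]]:
    """Split commits into discard/omit/add/new, mirroring the email's wording."""
    old = reachable(parents, old_tip)
    new = reachable(parents, new_tip)
    # Commits that some other reference (hypothetical here) still points at.
    elsewhere: Set[str] = set().union(*(reachable(parents, t) for t in other_tips))
    gone = old - new
    gained = new - old
    return {
        "discard": gone - elsewhere,  # no longer referenced anywhere
        "omit": gone & elsewhere,     # dropped here, still referenced elsewhere
        "add": gained & elsewhere,    # already in the repository, new to this ref
        "new": gained - elsewhere,    # entirely new to the repository
    }


# Hypothetical graph: B is the common base, O* the old revisions, N* the new
# ones, and M1 a commit that already exists on another branch.
graph = {"B": [], "O1": ["B"], "O2": ["O1"], "M1": ["B"], "N1": ["M1"], "N2": ["N1"]}
result = classify_force_push(graph, old_tip="O2", new_tip="N2", other_tips=["M1"])
```

In this toy graph, O1 and O2 come out as "discard", M1 as "add", and N1 and N2 as "new".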


Summary of changes:
 .dockerignore                                      |    2 +-
 .github/PULL_REQUEST_TEMPLATE.md                   |    2 +-
 .pre-commit-config.yaml                            |   10 +-
 .rat-excludes                                      |    1 -
 CONTRIBUTING.rst                                   |   19 +
 MANIFEST.in                                        |    2 +-
 CHANGELOG.txt => RELEASE_NOTES.rst                 | 4632 +++++++++++++++++++-
 STATIC_CODE_CHECKS.rst                             |    2 +
 UPDATING.md                                        | 3665 ----------------
 airflow/config_templates/config.yml                |    2 +-
 airflow/config_templates/default_airflow.cfg       |    2 +-
 airflow/models/abstractoperator.py                 |    4 +-
 .../amazon/aws/example_dags/example_s3.py          |   21 +
 airflow/providers/amazon/aws/operators/s3.py       |   88 +
 .../cncf/kubernetes/operators/kubernetes_pod.py    |    2 +-
 airflow/www/static/js/tree/Clipboard.jsx           |    4 +-
 breeze-complete                                    |    1 +
 chart/CHANGELOG.txt                                |  246 --
 chart/RELEASE_NOTES.rst                            |  412 ++
 chart/UPDATING.rst                                 |  172 -
 chart/newsfragments/22724.significant.rst          |    3 +
 chart/newsfragments/config.toml                    |   34 +
 dev/README_RELEASE_AIRFLOW.md                      |   29 +-
 dev/README_RELEASE_HELM_CHART.md                   |   27 +-
 dev/breeze/src/airflow_breeze/pre_commit_ids.py    |    1 +
 .../airflow_breeze/utils/docker_command_utils.py   |    2 +-
 .../operators/s3.rst                               |   15 +
 docs/apache-airflow/index.rst                      |    2 +-
 docs/apache-airflow/redirects.txt                  |    3 +
 .../release_notes.rst}                             |    6 +-
 docs/helm-chart/index.rst                          |    3 +-
 MANIFEST.in => docs/helm-chart/redirects.txt       |   26 +-
 .../changelog.rst => helm-chart/release_notes.rst} |    6 +-
 docs/helm-chart/updating.rst                       |   18 -
 docs/spelling_wordlist.txt                         |    8 +
 images/breeze/output-static-checks.svg             |   19 +-
 newsfragments/16931.improvement.rst                |    1 +
 newsfragments/17349.feature.rst                    |    1 +
 newsfragments/19482.improvement.rst                |    1 +
 newsfragments/19825.significant.rst                |    3 +
 newsfragments/20165.significant.rst                |    5 +
 newsfragments/20759.significant.rst                |    3 +
 newsfragments/20975.significant.rst                |    5 +
 newsfragments/21135.significant.rst                |   16 +
 newsfragments/21205.significant.rst                |   16 +
 newsfragments/21472.significant.rst                |    5 +
 newsfragments/21505.significant.rst                |    3 +
 newsfragments/21538.significant.rst                |    3 +
 newsfragments/21640.significant.rst                |    3 +
 newsfragments/21734.significant.rst                |   13 +
 newsfragments/21798.significant.rst                |    9 +
 newsfragments/21815.significant.2.rst              |    3 +
 newsfragments/21815.significant.rst                |    3 +
 newsfragments/21816.significant.rst                |   43 +
 newsfragments/22167.significant.rst                |    3 +
 newsfragments/22284.significant.rst                |   16 +
 newsfragments/config.toml                          |   34 +
 scripts/ci/docker-compose/local.yml                |    2 +-
 scripts/ci/libraries/_local_mounts.sh              |    2 +-
 scripts/ci/pre_commit/pre_commit_newsfragments.py  |   52 +
 setup.py                                           |    1 +
 .../amazon/aws/operators/test_s3_copy_object.py    |   83 -
 ...test_s3_delete_objects.py => test_s3_object.py} |  101 +-
 tests/providers/docker/decorators/test_docker.py   |   10 +-
 64 files changed, 5448 insertions(+), 4483 deletions(-)
 rename CHANGELOG.txt => RELEASE_NOTES.rst (60%)
 delete mode 100644 UPDATING.md
 delete mode 100644 chart/CHANGELOG.txt
 create mode 100644 chart/RELEASE_NOTES.rst
 delete mode 100644 chart/UPDATING.rst
 create mode 100644 chart/newsfragments/22724.significant.rst
 create mode 100644 chart/newsfragments/config.toml
 rename docs/{helm-chart/changelog.rst => apache-airflow/release_notes.rst} (92%)
 copy MANIFEST.in => docs/helm-chart/redirects.txt (53%)
 rename docs/{apache-airflow/changelog.rst => helm-chart/release_notes.rst} (91%)
 delete mode 100644 docs/helm-chart/updating.rst
 create mode 100644 newsfragments/16931.improvement.rst
 create mode 100644 newsfragments/17349.feature.rst
 create mode 100644 newsfragments/19482.improvement.rst
 create mode 100644 newsfragments/19825.significant.rst
 create mode 100644 newsfragments/20165.significant.rst
 create mode 100644 newsfragments/20759.significant.rst
 create mode 100644 newsfragments/20975.significant.rst
 create mode 100644 newsfragments/21135.significant.rst
 create mode 100644 newsfragments/21205.significant.rst
 create mode 100644 newsfragments/21472.significant.rst
 create mode 100644 newsfragments/21505.significant.rst
 create mode 100644 newsfragments/21538.significant.rst
 create mode 100644 newsfragments/21640.significant.rst
 create mode 100644 newsfragments/21734.significant.rst
 create mode 100644 newsfragments/21798.significant.rst
 create mode 100644 newsfragments/21815.significant.2.rst
 create mode 100644 newsfragments/21815.significant.rst
 create mode 100644 newsfragments/21816.significant.rst
 create mode 100644 newsfragments/22167.significant.rst
 create mode 100644 newsfragments/22284.significant.rst
 create mode 100644 newsfragments/config.toml
 create mode 100755 scripts/ci/pre_commit/pre_commit_newsfragments.py
 delete mode 100644 tests/providers/amazon/aws/operators/test_s3_copy_object.py
 rename tests/providers/amazon/aws/operators/{test_s3_delete_objects.py => test_s3_object.py} (67%)


[airflow] 13/19: fixup! Accept multiple map_index param from front end

Posted by bb...@apache.org.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 3327e0dc10fcb50c2bf112433f257bfd66fef62a
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Wed Apr 20 10:35:00 2022 +0800

    fixup! Accept multiple map_index param from front end
---
 airflow/www/views.py | 44 ++++++++++++++++++++++++--------------------
 1 file changed, 24 insertions(+), 20 deletions(-)

diff --git a/airflow/www/views.py b/airflow/www/views.py
index aac30f64ff..ae0186e493 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1960,7 +1960,7 @@ class Airflow(AirflowBaseView):
 
     def _clear_dag_tis(
         self,
-        dag,
+        dag: DAG,
         start_date,
         end_date,
         origin,
@@ -1995,24 +1995,19 @@ class Airflow(AirflowBaseView):
         except AirflowException as ex:
             return redirect_or_json(origin, msg=str(ex), status="error")
 
-        if not tis:
-            msg = "No task instances to clear"
-            return redirect_or_json(origin, msg, status="error")
-        elif request.headers.get('Accept') == 'application/json':
-            details = [str(t) for t in tis]
+        assert isinstance(tis, collections.abc.Iterable)
+        details = [str(t) for t in tis]
 
+        if not details:
+            return redirect_or_json(origin, "No task instances to clear", status="error")
+        elif request.headers.get('Accept') == 'application/json':
             return htmlsafe_json_dumps(details, separators=(',', ':'))
-        else:
-            details = "\n".join(str(t) for t in tis)
-
-            response = self.render_template(
-                'airflow/confirm.html',
-                endpoint=None,
-                message="Task instances you are about to clear:",
-                details=details,
-            )
-
-        return response
+        return self.render_template(
+            'airflow/confirm.html',
+            endpoint=None,
+            message="Task instances you are about to clear:",
+            details="\n".join(details),
+        )
 
     @expose('/clear', methods=['POST'])
     @auth.has_access(
@@ -2028,7 +2023,11 @@ class Airflow(AirflowBaseView):
         task_id = request.form.get('task_id')
         origin = get_safe_url(request.form.get('origin'))
         dag = current_app.dag_bag.get_dag(dag_id)
-        map_index = request.form.get('map_index')
+
+        if 'map_index' not in request.form:
+            map_indexes: Optional[List[int]] = None
+        else:
+            map_indexes = request.form.getlist('map_index', type=int)
 
         execution_date = request.form.get('execution_date')
         execution_date = timezone.parse(execution_date)
@@ -2047,7 +2046,12 @@ class Airflow(AirflowBaseView):
         )
         end_date = execution_date if not future else None
         start_date = execution_date if not past else None
-        task_ids = [(task_id, map_index)] if map_index else [task_id]
+
+        if map_indexes is None:
+            task_ids: Union[List[str], List[Tuple[str, int]]] = [task_id]
+        else:
+            task_ids = [(task_id, map_index) for map_index in map_indexes]
+
         return self._clear_dag_tis(
             dag,
             start_date,
@@ -2298,7 +2302,7 @@ class Airflow(AirflowBaseView):
         past: bool,
         state: TaskInstanceState,
     ):
-        dag = current_app.dag_bag.get_dag(dag_id)
+        dag: DAG = current_app.dag_bag.get_dag(dag_id)
 
         if not run_id:
             flash(f"Cannot mark tasks as {state}, seem that DAG {dag_id} has never run", "error")
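
The _clear_dag_tis hunk above builds the list of details once and then tests that list for emptiness. One plausible motivation (an assumption, the commit message does not say) is that `tis` may be a lazy iterable, which can be truthy while empty and can only be consumed once. A minimal sketch of the pattern, with a hypothetical function name and plain strings standing in for task instances:

```python
from typing import Iterable


def describe_for_confirmation(tis: Iterable[object]) -> str:
    """Materialize the iterable once, then branch on the resulting list."""
    details = [str(t) for t in tis]
    if not details:
        return "No task instances to clear"
    return "\n".join(details)


# A generator object is truthy even when it yields nothing, so testing the
# iterable itself would not detect the empty case.
empty_gen = (t for t in [])
generator_is_truthy = bool(empty_gen)
empty_message = describe_for_confirmation(empty_gen)
```

Here `generator_is_truthy` is True even though the generator is empty, while the function still reports that there is nothing to clear.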


[airflow] 08/19: Apply suggestions from code review

Posted by bb...@apache.org.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 9d5d8b28cb14ac6b780acb08f413f49ba0a9d93b
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Fri Apr 15 09:20:02 2022 +0100

    Apply suggestions from code review
    
    Co-authored-by: Jed Cunningham <66...@users.noreply.github.com>
    Co-authored-by: Tzu-ping Chung <ur...@gmail.com>
---
 airflow/api/common/mark_tasks.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index 8885ebb59e..a9e4f4812e 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -96,7 +96,7 @@ def set_state(
     tasks that did not exist. It will not create dag runs that are missing
     on the schedule (but it will as for subdag dag runs if needed).
 
-    :param tasks: the iterable of tasks or task, map_index tuple from which to work.
+    :param tasks: the iterable of tasks or (task, map_index) tuples from which to work.
         task.task.dag needs to be set
     :param run_id: the run_id of the dagrun to start looking from
     :param execution_date: the execution date from which to start looking(deprecated)
@@ -119,9 +119,10 @@ def set_state(
     if execution_date and not timezone.is_localized(execution_date):
         raise ValueError(f"Received non-localized date {execution_date}")
 
-    t_dags = {task.dag for task in tasks if not isinstance(task, tuple)}
-    t_dags_2 = {item[0].dag for item in tasks if isinstance(item, tuple)}
-    task_dags = t_dags | t_dags_2
+    task_dags = {
+        task[0].dag if isinstance(task, tuple) else task.dag
+        for task in tasks
+    }
     if len(task_dags) > 1:
         raise ValueError(f"Received tasks from multiple DAGs: {task_dags}")
     dag = next(iter(task_dags))
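
The set comprehension in the hunk above accepts either bare tasks or (task, map_index) tuples. The following self-contained sketch shows the same check in isolation; `FakeTask` and `single_dag` are hypothetical names used only for illustration, and a string stands in for the DAG object.

```python
from dataclasses import dataclass
from typing import Iterable, Set, Tuple, Union


@dataclass(frozen=True)
class FakeTask:
    """Hypothetical stand-in for an operator; only the fields used below."""

    task_id: str
    dag: str


TaskSpec = Union[FakeTask, Tuple[FakeTask, int]]


def single_dag(tasks: Iterable[TaskSpec]) -> str:
    """Return the one DAG all task specifiers belong to, or raise ValueError."""
    task_dags: Set[str] = {
        task[0].dag if isinstance(task, tuple) else task.dag for task in tasks
    }
    if len(task_dags) > 1:
        raise ValueError(f"Received tasks from multiple DAGs: {task_dags}")
    return next(iter(task_dags))


mixed = [FakeTask("extract", "etl"), (FakeTask("load", "etl"), 0)]
dag_name = single_dag(mixed)
```

As in the original code, an empty input is not handled specially here and would raise StopIteration from next().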


[airflow] 11/19: Accept multiple map_index param from front end

Posted by bb...@apache.org.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 03676ff07035340e10e8e19ead1be7c9a84ae3b3
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Tue Apr 19 09:16:04 2022 +0800

    Accept multiple map_index param from front end
    
    This allows setting multiple instances of the same task to SUCCESS or
    FAILED in one request. This is translated to multiple task specifier
    tuples (task_id, map_index) when passed to set_state().
    
    Also made some drive-by improvements, adding types and cleaning up some
    formatting.
---
 airflow/api/common/mark_tasks.py |   4 +-
 airflow/models/dag.py            |  12 ++---
 airflow/www/views.py             | 105 ++++++++++++++++++++++-----------------
 3 files changed, 68 insertions(+), 53 deletions(-)

diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index 594423305c..349b935e82 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -18,7 +18,7 @@
 """Marks tasks APIs."""
 
 from datetime import datetime
-from typing import TYPE_CHECKING, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Collection, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union
 
 from sqlalchemy import or_, tuple_
 from sqlalchemy.orm import contains_eager
@@ -78,7 +78,7 @@ def _create_dagruns(
 @provide_session
 def set_state(
     *,
-    tasks: Union[Iterable[Operator], Iterable[Tuple[Operator, int]]],
+    tasks: Union[Collection[Operator], Collection[Tuple[Operator, int]]],
     run_id: Optional[str] = None,
     execution_date: Optional[datetime] = None,
     upstream: bool = False,
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 755505b5d0..9c93bcef13 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -1620,7 +1620,7 @@ class DAG(LoggingMixin):
         self,
         *,
         task_id: str,
-        map_index: Optional[int] = None,
+        map_indexes: Optional[Collection[int]] = None,
         execution_date: Optional[datetime] = None,
         run_id: Optional[str] = None,
         state: TaskInstanceState,
@@ -1636,8 +1636,8 @@ class DAG(LoggingMixin):
         in failed or upstream_failed state.
 
         :param task_id: Task ID of the TaskInstance
-        :param map_index: The TaskInstance map_index, if None, would set state for all mapped
-            TaskInstances of the task
+        :param map_indexes: Only set TaskInstance if its map_index matches.
+            If None (default), all mapped TaskInstances of the task are set.
         :param execution_date: Execution date of the TaskInstance
         :param run_id: The run_id of the TaskInstance
         :param state: State to set the TaskInstance to
@@ -1665,12 +1665,12 @@ class DAG(LoggingMixin):
 
         tasks_to_set_state: Union[List[Operator], List[Tuple[Operator, int]]]
         task_ids_to_exclude_from_clear: Union[Set[str], Set[Tuple[str, int]]]
-        if map_index is None:
+        if map_indexes is None:
             tasks_to_set_state = [task]
             task_ids_to_exclude_from_clear = {task_id}
         else:
-            tasks_to_set_state = [(task, map_index)]
-            task_ids_to_exclude_from_clear = {(task_id, map_index)}
+            tasks_to_set_state = [(task, map_index) for map_index in map_indexes]
+            task_ids_to_exclude_from_clear = {(task_id, map_index) for map_index in map_indexes}
 
         altered = set_state(
             tasks=tasks_to_set_state,
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 437a60cca0..aac30f64ff 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -95,6 +95,7 @@ from airflow.api.common.mark_tasks import (
     set_dag_run_state_to_failed,
     set_dag_run_state_to_queued,
     set_dag_run_state_to_success,
+    set_state,
 )
 from airflow.compat.functools import cached_property
 from airflow.configuration import AIRFLOW_CONFIG, conf
@@ -107,6 +108,7 @@ from airflow.models import DAG, Connection, DagModel, DagTag, Log, SlaMiss, Task
 from airflow.models.abstractoperator import AbstractOperator
 from airflow.models.dagcode import DagCode
 from airflow.models.dagrun import DagRun, DagRunType
+from airflow.models.operator import Operator
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskinstance import TaskInstance
 from airflow.providers_manager import ProvidersManager
@@ -2284,28 +2286,28 @@ class Airflow(AirflowBaseView):
 
     def _mark_task_instance_state(
         self,
-        dag_id,
-        task_id,
-        origin,
-        dag_run_id,
-        upstream,
-        downstream,
-        future,
-        past,
-        state,
-        map_index=None,
+        *,
+        dag_id: str,
+        run_id: str,
+        task_id: str,
+        map_indexes: Optional[List[int]],
+        origin: str,
+        upstream: bool,
+        downstream: bool,
+        future: bool,
+        past: bool,
+        state: TaskInstanceState,
     ):
         dag = current_app.dag_bag.get_dag(dag_id)
-        latest_execution_date = dag.get_latest_execution_date()
 
-        if not latest_execution_date:
-            flash(f"Cannot mark tasks as {state}, seem that dag {dag_id} has never run", "error")
+        if not run_id:
+            flash(f"Cannot mark tasks as {state}, seem that DAG {dag_id} has never run", "error")
             return redirect(origin)
 
         altered = dag.set_task_instance_state(
             task_id=task_id,
-            map_index=map_index,
-            run_id=dag_run_id,
+            map_indexes=map_indexes,
+            run_id=run_id,
             state=state,
             upstream=upstream,
             downstream=downstream,
@@ -2332,7 +2334,11 @@ class Airflow(AirflowBaseView):
         dag_run_id = args.get('dag_run_id')
         state = args.get('state')
         origin = args.get('origin')
-        map_index = args.get('map_index')
+
+        if 'map_index' not in args:
+            map_indexes: Optional[List[int]] = None
+        else:
+            map_indexes = args.getlist('map_index', type=int)
 
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
@@ -2365,9 +2371,10 @@ class Airflow(AirflowBaseView):
             msg = f"Cannot mark tasks as {state}, seem that dag {dag_id} has never run"
             return redirect_or_json(origin, msg, status='error')
 
-        from airflow.api.common.mark_tasks import set_state
-
-        tasks = [(task, map_index)] if map_index else [task]
+        if map_indexes is None:
+            tasks: Union[List[Operator], List[Tuple[Operator, int]]] = [task]
+        else:
+            tasks = [(task, map_index) for map_index in map_indexes]
 
         to_be_altered = set_state(
             tasks=tasks,
@@ -2408,26 +2415,30 @@ class Airflow(AirflowBaseView):
         args = request.form
         dag_id = args.get('dag_id')
         task_id = args.get('task_id')
-        origin = get_safe_url(args.get('origin'))
-        dag_run_id = args.get('dag_run_id')
-        map_index = args.get('map_index')
+        run_id = args.get('dag_run_id')
 
+        if 'map_index' not in args:
+            map_indexes: Optional[List[int]] = None
+        else:
+            map_indexes = args.getlist('map_index', type=int)
+
+        origin = get_safe_url(args.get('origin'))
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
         future = to_boolean(args.get('future'))
         past = to_boolean(args.get('past'))
 
         return self._mark_task_instance_state(
-            dag_id,
-            task_id,
-            origin,
-            dag_run_id,
-            upstream,
-            downstream,
-            future,
-            past,
-            State.FAILED,
-            map_index=map_index,
+            dag_id=dag_id,
+            run_id=run_id,
+            task_id=task_id,
+            map_indexes=map_indexes,
+            origin=origin,
+            upstream=upstream,
+            downstream=downstream,
+            future=future,
+            past=past,
+            state=TaskInstanceState.FAILED,
         )
 
     @expose('/success', methods=['POST'])
@@ -2443,26 +2454,30 @@ class Airflow(AirflowBaseView):
         args = request.form
         dag_id = args.get('dag_id')
         task_id = args.get('task_id')
-        origin = get_safe_url(args.get('origin'))
-        dag_run_id = args.get('dag_run_id')
-        map_index = args.get('map_index')
+        run_id = args.get('dag_run_id')
+
+        if 'map_index' not in args:
+            map_indexes: Optional[List[int]] = None
+        else:
+            map_indexes = args.getlist('map_index', type=int)
 
+        origin = get_safe_url(args.get('origin'))
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
         future = to_boolean(args.get('future'))
         past = to_boolean(args.get('past'))
 
         return self._mark_task_instance_state(
-            dag_id,
-            task_id,
-            origin,
-            dag_run_id,
-            upstream,
-            downstream,
-            future,
-            past,
-            State.SUCCESS,
-            map_index=map_index,
+            dag_id=dag_id,
+            run_id=run_id,
+            task_id=task_id,
+            map_indexes=map_indexes,
+            origin=origin,
+            upstream=upstream,
+            downstream=downstream,
+            future=future,
+            past=past,
+            state=TaskInstanceState.SUCCESS,
         )
 
     @expose('/dags/<string:dag_id>')
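
The commit message for this change describes turning repeated map_index parameters into (task_id, map_index) specifier tuples. The sketch below shows that translation outside of the web framework. It uses the standard library's parse_qs on a query string instead of the request form object used in the views, and both function names are hypothetical.

```python
from typing import List, Optional, Tuple, Union
from urllib.parse import parse_qs


def parse_map_indexes(encoded_form: str) -> Optional[List[int]]:
    """Return None when map_index is absent, else every value as an int."""
    params = parse_qs(encoded_form)
    if "map_index" not in params:
        return None
    return [int(value) for value in params["map_index"]]


def build_task_specs(
    task_id: str, map_indexes: Optional[List[int]]
) -> Union[List[str], List[Tuple[str, int]]]:
    """None selects the whole task; a list selects specific mapped instances."""
    if map_indexes is None:
        return [task_id]
    return [(task_id, map_index) for map_index in map_indexes]


specs = build_task_specs(
    "consumer", parse_map_indexes("task_id=consumer&map_index=0&map_index=2")
)
whole_task = build_task_specs("consumer", parse_map_indexes("task_id=consumer"))
```

Checking for None explicitly, as the diff does, keeps "parameter not sent" distinct from a map index of 0, which a plain truthiness test on an integer would conflate.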


[airflow] 04/19: add tests

Posted by bb...@apache.org.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit bcf186d9ebd5202b0d26aec244241da9d6482622
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Tue Apr 12 20:27:39 2022 +0100

    add tests
---
 tests/api/common/test_mark_tasks.py | 6 +++---
 tests/models/test_dag.py            | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/api/common/test_mark_tasks.py b/tests/api/common/test_mark_tasks.py
index 4c1d3c604b..dedb624b1f 100644
--- a/tests/api/common/test_mark_tasks.py
+++ b/tests/api/common/test_mark_tasks.py
@@ -439,12 +439,12 @@ class TestMarkTasks:
     def test_mark_mapped_task_instance_state(self):
         # set mapped task instance to success
         snapshot = TestMarkTasks.snapshot_state(self.dag4, self.execution_dates)
-        task = self.dag4.get_task("consumer_literal")
-        tasks = [(task, 0), (task, 1)]
+        tasks = [self.dag4.get_task("consumer_literal")]
         map_indexes = [0, 1]
         dr = DagRun.find(dag_id=self.dag4.dag_id, execution_date=self.execution_dates[0])[0]
         altered = set_state(
             tasks=tasks,
+            map_indexes=map_indexes,
             run_id=dr.run_id,
             upstream=False,
             downstream=False,
@@ -456,7 +456,7 @@ class TestMarkTasks:
         assert len(altered) == 2
         self.verify_state(
             self.dag4,
-            [task.task_id for task, _ in tasks],
+            [task.task_id for task in tasks],
             [self.execution_dates[0]],
             State.SUCCESS,
             snapshot,
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 6cd8ea660f..41219ef87b 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -1456,7 +1456,7 @@ class TestDag(unittest.TestCase):
         session.flush()
 
         dag.clear(
-            task_ids=[(task_id, 0)],
+            map_indexes=[0],
             start_date=DEFAULT_DATE,
             end_date=DEFAULT_DATE + datetime.timedelta(days=1),
             dag_run_state=dag_run_state,


[airflow] 19/19: Get gantt/graph modal actions working again

Posted by bb...@apache.org.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 1eb9b5cbdcb444e355e65b3027cbdaf70d639179
Author: Brent Bovenzi <br...@gmail.com>
AuthorDate: Wed Apr 20 10:29:06 2022 -0400

    Get gantt/graph modal actions working again
---
 airflow/www/static/js/dag.js           | 9 ++++-----
 airflow/www/templates/airflow/dag.html | 6 +++---
 airflow/www/utils.py                   | 1 +
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/airflow/www/static/js/dag.js b/airflow/www/static/js/dag.js
index ded26baeab..0445e0686c 100644
--- a/airflow/www/static/js/dag.js
+++ b/airflow/www/static/js/dag.js
@@ -347,11 +347,10 @@ $('form[data-action]').on('submit', function submit(e) {
     if (form.task_id) {
       form.task_id.value = taskId;
     }
-    if (form.map_index) {
-      form.map_index.value = mapIndex === undefined ? '' : mapIndex;
-    }
-    if (form.map_indexes) {
-      form.map_indexes.value = mapIndex === undefined ? '' : mapIndex;
+    if (form.map_index && mapIndex >= 0) {
+      form.map_index.value = mapIndex;
+    } else if (form.map_index) {
+      form.map_index.remove();
     }
     form.action = $(this).data('action');
     form.submit();
diff --git a/airflow/www/templates/airflow/dag.html b/airflow/www/templates/airflow/dag.html
index 8259c7045c..e7afde8af9 100644
--- a/airflow/www/templates/airflow/dag.html
+++ b/airflow/www/templates/airflow/dag.html
@@ -310,7 +310,7 @@
             <input type="hidden" name="dag_id" value="{{ dag.dag_id }}">
             <input type="hidden" name="task_id">
             <input type="hidden" name="execution_date">
-            <input type="hidden" name="map_indexes">
+            <input type="hidden" name="map_index">
             <input type="hidden" name="origin" value="{{ request.base_url }}">
             <div class="row">
               <span class="btn-group col-xs-12 col-sm-9 task-instance-modal-column" data-toggle="buttons">
@@ -352,7 +352,7 @@
             <input type="hidden" name="dag_id" value="{{ dag.dag_id }}">
             <input type="hidden" name="task_id">
             <input type="hidden" name="dag_run_id">
-            <input type="hidden" name="map_indexes">
+            <input type="hidden" name="map_index">
             <input type="hidden" name="origin" value="{{ request.base_url }}">
             <input type="hidden" name="state" value="failed">
             <div class="row">
@@ -386,7 +386,7 @@
             <input type="hidden" name="dag_id" value="{{ dag.dag_id }}">
             <input type="hidden" name="task_id">
             <input type="hidden" name="dag_run_id">
-            <input type="hidden" name="map_indexes">
+            <input type="hidden" name="map_index">
             <input type="hidden" name="origin" value="{{ request.base_url }}">
             <input type="hidden" name="state" value="success">
             <div class="row">
diff --git a/airflow/www/utils.py b/airflow/www/utils.py
index 8a5cc3c717..0c452dea38 100644
--- a/airflow/www/utils.py
+++ b/airflow/www/utils.py
@@ -64,6 +64,7 @@ def get_mapped_instances(task_instance, session):
             TaskInstance.run_id == task_instance.run_id,
             TaskInstance.task_id == task_instance.task_id,
         )
+        .order_by(TaskInstance.map_index)
         .all()
     )
 


[airflow] 05/19: fixup! fixup! fixup! fixup! Allow marking/clearing mapped taskinstances from the UI

Posted by bb...@apache.org.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit c5cc48f9f9161b187368d99b551c652e1da03de5
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Wed Apr 13 20:16:56 2022 +0100

    fixup! fixup! fixup! fixup! Allow marking/clearing mapped taskinstances from the UI
---
 airflow/api/common/mark_tasks.py    |  38 ++++-----
 airflow/models/dag.py               | 159 ++++++++++++++++++++----------------
 airflow/www/views.py                |  44 ++++++----
 tests/api/common/test_mark_tasks.py |   6 +-
 tests/models/test_dag.py            |   2 +-
 5 files changed, 133 insertions(+), 116 deletions(-)

diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index 1d4709fb82..84fd48f4e4 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -18,9 +18,9 @@
 """Marks tasks APIs."""
 
 from datetime import datetime
-from typing import TYPE_CHECKING, Collection, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union
 
-from sqlalchemy import or_
+from sqlalchemy import or_, tuple_
 from sqlalchemy.orm import contains_eager
 from sqlalchemy.orm.session import Session as SASession
 
@@ -32,7 +32,6 @@ from airflow.operators.subdag import SubDagOperator
 from airflow.utils import timezone
 from airflow.utils.helpers import exactly_one
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import tuple_in_condition
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunType
 
@@ -79,7 +78,7 @@ def _create_dagruns(
 @provide_session
 def set_state(
     *,
-    tasks: Union[Collection[Operator], Collection[Tuple[Operator, int]]],
+    tasks: Union[Iterable[Operator], Iterable[Tuple[Operator, int]]],
     run_id: Optional[str] = None,
     execution_date: Optional[datetime] = None,
     upstream: bool = False,
@@ -97,7 +96,7 @@ def set_state(
     tasks that did not exist. It will not create dag runs that are missing
     on the schedule (but it will as for subdag dag runs if needed).
 
-    :param tasks: the iterable of tasks or (task, map_index) tuples from which to work.
+    :param tasks: the iterable of tasks or task, map_index tuple from which to work.
         task.task.dag needs to be set
     :param run_id: the run_id of the dagrun to start looking from
     :param execution_date: the execution date from which to start looking(deprecated)
@@ -120,7 +119,9 @@ def set_state(
     if execution_date and not timezone.is_localized(execution_date):
         raise ValueError(f"Received non-localized date {execution_date}")
 
-    task_dags = {task[0].dag if isinstance(task, tuple) else task.dag for task in tasks}
+    t_dags = {task.dag for task in tasks if not isinstance(task, tuple)}
+    t_dags_2 = {item[0].dag for item in tasks if isinstance(item, tuple)}
+    task_dags = t_dags | t_dags_2
     if len(task_dags) > 1:
         raise ValueError(f"Received tasks from multiple DAGs: {task_dags}")
     dag = next(iter(task_dags))
@@ -135,12 +136,6 @@ def set_state(
     dag_run_ids = get_run_ids(dag, run_id, future, past)
     task_id_map_index_list = list(find_task_relatives(tasks, downstream, upstream))
     task_ids = [task_id for task_id, _ in task_id_map_index_list]
-    # check if task_id_map_index_list contains map_index of None
-    # if it contains None, there was no map_index supplied for the task
-    for _, index in task_id_map_index_list:
-        if index is None:
-            task_id_map_index_list = [task_id for task_id, _ in task_id_map_index_list]
-            break
 
     confirmed_infos = list(_iter_existing_dag_run_infos(dag, dag_run_ids))
     confirmed_dates = [info.logical_date for info in confirmed_infos]
@@ -187,26 +182,20 @@ def get_all_dag_task_query(
     dag: DAG,
     session: SASession,
     state: TaskInstanceState,
-    task_ids: Union[List[str], List[Tuple[str, int]]],
+    task_id_map_index_list: List[Tuple[str, int]],
     confirmed_dates: Iterable[datetime],
 ):
     """Get all tasks of the main dag that will be affected by a state change"""
-    is_string_list = isinstance(task_ids[0], str)
     qry_dag = (
         session.query(TaskInstance)
         .join(TaskInstance.dag_run)
         .filter(
             TaskInstance.dag_id == dag.dag_id,
             DagRun.execution_date.in_(confirmed_dates),
+            tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_id_map_index_list),
         )
-    )
-
-    if is_string_list:
-        qry_dag = qry_dag.filter(TaskInstance.task_id.in_(task_ids))
-    else:
-        qry_dag = qry_dag.filter(tuple_in_condition((TaskInstance.task_id, TaskInstance.map_index), task_ids))
-    qry_dag = qry_dag.filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state)).options(
-        contains_eager(TaskInstance.dag_run)
+        .filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
+        .options(contains_eager(TaskInstance.dag_run))
     )
     return qry_dag
 
@@ -282,13 +271,14 @@ def _iter_existing_dag_run_infos(dag: DAG, run_ids: List[str]) -> Iterator[_DagR
         yield _DagRunInfo(dag_run.logical_date, dag.get_run_data_interval(dag_run))
 
 
-def find_task_relatives(tasks, downstream, upstream):
+@provide_session
+def find_task_relatives(tasks, downstream, upstream, session: SASession = NEW_SESSION):
     """Yield task ids and optionally ancestor and descendant ids."""
     for item in tasks:
         if isinstance(item, tuple):
             task, map_index = item
         else:
-            task, map_index = item, None
+            task, map_index = item, -1
         yield task.task_id, map_index
         if downstream:
             for relative in task.get_flat_relatives(upstream=False):
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index e9c33acb72..931fd469d7 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -39,7 +39,6 @@ from typing import (
     Iterable,
     List,
     Optional,
-    Sequence,
     Set,
     Tuple,
     Type,
@@ -52,7 +51,7 @@ import jinja2
 import pendulum
 from dateutil.relativedelta import relativedelta
 from pendulum.tz.timezone import Timezone
-from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, not_, or_
+from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, or_, tuple_
 from sqlalchemy.orm import backref, joinedload, relationship
 from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.session import Session
@@ -85,7 +84,7 @@ from airflow.utils.file import correct_maybe_zipped
 from airflow.utils.helpers import exactly_one, validate_key
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, tuple_in_condition, with_row_locks
+from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, with_row_locks
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType
 
@@ -1341,33 +1340,47 @@ class DAG(LoggingMixin):
             start_date = (timezone.utcnow() - timedelta(30)).replace(
                 hour=0, minute=0, second=0, microsecond=0
             )
-        query = self._get_task_instances(
-            task_ids=None,
-            start_date=start_date,
-            end_date=end_date,
-            run_id=None,
-            state=state or (),
-            include_subdags=False,
-            include_parentdag=False,
-            include_dependent_dags=False,
-            exclude_task_ids=(),
-            session=session,
+
+        if state is None:
+            state = []
+
+        return (
+            cast(
+                Query,
+                self._get_task_instances(
+                    task_ids=None,
+                    task_ids_and_map_indexes=None,
+                    start_date=start_date,
+                    end_date=end_date,
+                    run_id=None,
+                    state=state,
+                    include_subdags=False,
+                    include_parentdag=False,
+                    include_dependent_dags=False,
+                    exclude_task_ids=cast(List[str], []),
+                    exclude_task_ids_and_map_indexes=None,
+                    session=session,
+                ),
+            )
+            .order_by(DagRun.execution_date)
+            .all()
         )
-        return cast(Query, query).order_by(DagRun.execution_date).all()
 
     @overload
     def _get_task_instances(
         self,
         *,
-        task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
+        task_ids: Iterable[str],
+        task_ids_and_map_indexes: Optional[Iterable[Tuple[str, int]]],
         start_date: Optional[datetime],
         end_date: Optional[datetime],
         run_id: Optional[str],
-        state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
+        state: Union[TaskInstanceState, List[TaskInstanceState]],
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
-        exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
+        exclude_task_ids: Collection[str],
+        exclude_task_ids_and_map_indexes: Collection[Tuple[str, int]],
         session: Session,
         dag_bag: Optional["DagBag"] = ...,
     ) -> Iterable[TaskInstance]:
@@ -1377,16 +1390,18 @@ class DAG(LoggingMixin):
     def _get_task_instances(
         self,
         *,
-        task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
+        task_ids: Iterable[str],
+        task_ids_and_map_indexes: Optional[Iterable[Tuple[str, int]]],
         as_pk_tuple: Literal[True],
         start_date: Optional[datetime],
         end_date: Optional[datetime],
         run_id: Optional[str],
-        state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
+        state: Union[TaskInstanceState, List[TaskInstanceState]],
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
-        exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
+        exclude_task_ids: Collection[str],
+        exclude_task_ids_and_map_indexes: Collection[Tuple[str, int]],
         session: Session,
         dag_bag: Optional["DagBag"] = ...,
         recursion_depth: int = ...,
@@ -1398,16 +1413,18 @@ class DAG(LoggingMixin):
     def _get_task_instances(
         self,
         *,
-        task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
+        task_ids: Iterable[str],
+        task_ids_and_map_indexes: Optional[Iterable[Tuple[str, int]]],
         as_pk_tuple: Literal[True, None] = None,
         start_date: Optional[datetime],
         end_date: Optional[datetime],
         run_id: Optional[str],
-        state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
+        state: Union[TaskInstanceState, List[TaskInstanceState]],
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
-        exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
+        exclude_task_ids: Collection[str],
+        exclude_task_ids_and_map_indexes: Collection[Tuple[str, int]],
         session: Session,
         dag_bag: Optional["DagBag"] = None,
         recursion_depth: int = 0,
@@ -1431,6 +1448,11 @@ class DAG(LoggingMixin):
             tis = session.query(TaskInstance)
         tis = tis.join(TaskInstance.dag_run)
 
+        if task_ids is not None:  # task not mapped
+            task_ids_and_map_indexes = [(task_id, -1) for task_id in task_ids]
+        if exclude_task_ids:  # task not mapped
+            exclude_task_ids_and_map_indexes = [(task_id, -1) for task_id in exclude_task_ids]
+
         if include_subdags:
             # Crafting the right filter for dag_id and task_ids combo
             conditions = []
@@ -1445,13 +1467,10 @@ class DAG(LoggingMixin):
             tis = tis.filter(TaskInstance.run_id == run_id)
         if start_date:
             tis = tis.filter(DagRun.execution_date >= start_date)
-
-        if task_ids is None:
-            pass  # Disable filter if not set.
-        elif isinstance(next(iter(task_ids), None), str):
-            tis = tis.filter(TI.task_id.in_(task_ids))
-        else:
-            tis = tis.filter(tuple_in_condition((TI.task_id, TI.map_index), task_ids))
+        if task_ids_and_map_indexes:
+            tis = tis.filter(
+                tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_ids_and_map_indexes)
+            )
 
         # This allows allow_trigger_in_future config to take affect, rather than mandating exec_date <= UTC
         if end_date or not self.allow_future_exec_dates:
@@ -1490,6 +1509,7 @@ class DAG(LoggingMixin):
             result.update(
                 p_dag._get_task_instances(
                     task_ids=task_ids,
+                    task_ids_and_map_indexes=task_ids_and_map_indexes,
                     start_date=start_date,
                     end_date=end_date,
                     run_id=None,
@@ -1499,6 +1519,7 @@ class DAG(LoggingMixin):
                     include_dependent_dags=include_dependent_dags,
                     as_pk_tuple=True,
                     exclude_task_ids=exclude_task_ids,
+                    exclude_task_ids_and_map_indexes=exclude_task_ids_and_map_indexes,
                     session=session,
                     dag_bag=dag_bag,
                     recursion_depth=recursion_depth,
@@ -1566,7 +1587,7 @@ class DAG(LoggingMixin):
                     )
                     result.update(
                         downstream._get_task_instances(
-                            task_ids=None,
+                            task_ids_and_map_indexes=None,
                             run_id=tii.run_id,
                             start_date=None,
                             end_date=None,
@@ -1575,7 +1596,7 @@ class DAG(LoggingMixin):
                             include_dependent_dags=include_dependent_dags,
                             include_parentdag=False,
                             as_pk_tuple=True,
-                            exclude_task_ids=exclude_task_ids,
+                            exclude_task_ids_and_map_indexes=exclude_task_ids_and_map_indexes,
                             dag_bag=dag_bag,
                             session=session,
                             recursion_depth=recursion_depth + 1,
@@ -1589,29 +1610,25 @@ class DAG(LoggingMixin):
             if as_pk_tuple:
                 result.update(TaskInstanceKey(*cols) for cols in tis.all())
             else:
-                result.update(ti.key for ti in tis)
-
-            if exclude_task_ids is not None:
-                result = {
-                    task
-                    for task in result
-                    if task.task_id not in exclude_task_ids
-                    and (task.task_id, task.map_index) not in exclude_task_ids
-                }
+                result.update(ti.key for ti in tis.all())
+
+            if exclude_task_ids_and_map_indexes:
+                result = set(
+                    filter(
+                        lambda key: (key.task_id, key.map_index) not in exclude_task_ids_and_map_indexes,
+                        result,
+                    )
+                )
 
         if as_pk_tuple:
             return result
-        if result:
+        elif result:
             # We've been asked for objects, lets combine it all back in to a result set
-            ti_filters = TI.filter_for_tis(result)
-            if ti_filters is not None:
-                tis = session.query(TI).filter(ti_filters)
-        elif exclude_task_ids is None:
-            pass  # Disable filter if not set.
-        elif isinstance(next(iter(exclude_task_ids), None), str):
-            tis = tis.filter(TI.task_id.notin_(exclude_task_ids))
-        else:
-            tis = tis.filter(not_(tuple_in_condition((TI.task_id, TI.map_index), exclude_task_ids)))
+            tis = tis.with_entities(TI.dag_id, TI.task_id, TI.run_id, TI.map_index)
+
+            tis = session.query(TI).filter(TI.filter_for_tis(result))
+        elif exclude_task_ids_and_map_indexes:
+            tis = tis.filter(tuple_(TI.task_id, TI.map_index).notin_(exclude_task_ids_and_map_indexes))
 
         return tis
 
@@ -1620,7 +1637,7 @@ class DAG(LoggingMixin):
         self,
         *,
         task_id: str,
-        map_indexes: Optional[Collection[int]] = None,
+        map_indexes: Optional[Iterable[int]] = None,
         execution_date: Optional[datetime] = None,
         run_id: Optional[str] = None,
         state: TaskInstanceState,
@@ -1636,8 +1653,7 @@ class DAG(LoggingMixin):
         in failed or upstream_failed state.
 
         :param task_id: Task ID of the TaskInstance
-        :param map_indexes: Only set TaskInstance if its map_index matches.
-            If None (default), all mapped TaskInstances of the task are set.
+        :param map_indexes: Task instance map_index to set the state of
         :param execution_date: Execution date of the TaskInstance
         :param run_id: The run_id of the TaskInstance
         :param state: State to set the TaskInstance to
@@ -1662,18 +1678,13 @@ class DAG(LoggingMixin):
 
         task = self.get_task(task_id)
         task.dag = self
-
-        tasks_to_set_state: Union[List[Operator], List[Tuple[Operator, int]]]
-        task_ids_to_exclude_from_clear: Union[Set[str], Set[Tuple[str, int]]]
-        if map_indexes is None:
-            tasks_to_set_state = [task]
-            task_ids_to_exclude_from_clear = {task_id}
-        else:
-            tasks_to_set_state = [(task, map_index) for map_index in map_indexes]
-            task_ids_to_exclude_from_clear = {(task_id, map_index) for map_index in map_indexes}
+        if not map_indexes:
+            map_indexes = [-1]
+        task_map_indexes = [(task, map_index) for map_index in map_indexes]
+        task_id_map_indexes = [(task_id, map_index) for map_index in map_indexes]
 
         altered = set_state(
-            tasks=tasks_to_set_state,
+            tasks=task_map_indexes,
             execution_date=execution_date,
             run_id=run_id,
             upstream=upstream,
@@ -1703,13 +1714,12 @@ class DAG(LoggingMixin):
         subdag.clear(
             start_date=start_date,
             end_date=end_date,
-            map_indexes=map_indexes,
             include_subdags=True,
             include_parentdag=True,
             only_failed=True,
             session=session,
             # Exclude the task itself from being cleared
-            exclude_task_ids=task_ids_to_exclude_from_clear,
+            exclude_task_ids_and_map_indexes=task_id_map_indexes,
         )
 
         return altered
@@ -1767,7 +1777,8 @@ class DAG(LoggingMixin):
     @provide_session
     def clear(
         self,
-        task_ids: Union[Collection[str], Collection[Tuple[str, int]], None] = None,
+        task_ids=None,
+        task_ids_and_map_indexes: Optional[Iterable[Tuple[str, int]]] = None,
         start_date: Optional[datetime] = None,
         end_date: Optional[datetime] = None,
         only_failed: bool = False,
@@ -1782,13 +1793,14 @@ class DAG(LoggingMixin):
         recursion_depth: int = 0,
         max_recursion_depth: Optional[int] = None,
         dag_bag: Optional["DagBag"] = None,
-        exclude_task_ids: Union[FrozenSet[str], FrozenSet[Tuple[str, int]], None] = frozenset(),
+        exclude_task_ids: FrozenSet[str] = frozenset(),
+        exclude_task_ids_and_map_indexes: FrozenSet[Tuple[str, int]] = frozenset({}),
     ) -> Union[int, Iterable[TaskInstance]]:
         """
         Clears a set of task instances associated with the current dag for
         a specified date range.
 
-        :param task_ids: List of task ids or (``task_id``, ``map_index``) tuples to clear
+        :param task_ids_and_map_indexes: List of tuple of task_id, map_index to clear
         :param start_date: The minimum execution_date to clear
         :param end_date: The maximum execution_date to clear
         :param only_failed: Only clear failed tasks
@@ -1802,7 +1814,8 @@ class DAG(LoggingMixin):
         :param dry_run: Find the tasks to clear but don't clear them.
         :param session: The sqlalchemy session to use
         :param dag_bag: The DagBag used to find the dags subdags (Optional)
-        :param exclude_task_ids: A set of ``task_id`` or (``task_id``, ``map_index``)
+        :param exclude_task_ids: A set of ``task_id`` that should not be cleared
+        :param exclude_task_ids_and_map_indexes: A set of ``task_id``,``map_index``
             tuples that should not be cleared
         """
         if get_tis:
@@ -1832,9 +1845,12 @@ class DAG(LoggingMixin):
         if only_running:
             # Yes, having `+=` doesn't make sense, but this was the existing behaviour
             state += [State.RUNNING]
+        if task_ids:
+            task_ids_and_map_indexes = [(task_id, -1) for task_id in task_ids]
 
         tis = self._get_task_instances(
             task_ids=task_ids,
+            task_ids_and_map_indexes=task_ids_and_map_indexes,
             start_date=start_date,
             end_date=end_date,
             run_id=None,
@@ -1845,6 +1861,7 @@ class DAG(LoggingMixin):
             session=session,
             dag_bag=dag_bag,
             exclude_task_ids=exclude_task_ids,
+            exclude_task_ids_and_map_indexes=exclude_task_ids_and_map_indexes,
         )
 
         if dry_run:
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 5e7f7a01e6..7b0f81af80 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1962,7 +1962,7 @@ class Airflow(AirflowBaseView):
         start_date,
         end_date,
         origin,
-        map_indexes=None,
+        task_id_map_index_list=None,
         recursive=False,
         confirmed=False,
         only_failed=False,
@@ -1971,7 +1971,7 @@ class Airflow(AirflowBaseView):
             count = dag.clear(
                 start_date=start_date,
                 end_date=end_date,
-                map_indexes=map_indexes,
+                task_ids_and_map_indexes=task_id_map_index_list,
                 include_subdags=recursive,
                 include_parentdag=recursive,
                 only_failed=only_failed,
@@ -1984,7 +1984,7 @@ class Airflow(AirflowBaseView):
             tis = dag.clear(
                 start_date=start_date,
                 end_date=end_date,
-                map_indexes=map_indexes,
+                task_ids_and_map_indexes=task_id_map_index_list,
                 include_subdags=recursive,
                 include_parentdag=recursive,
                 only_failed=only_failed,
@@ -2026,9 +2026,14 @@ class Airflow(AirflowBaseView):
         task_id = request.form.get('task_id')
         origin = get_safe_url(request.form.get('origin'))
         dag = current_app.dag_bag.get_dag(dag_id)
+
         map_indexes = request.form.get('map_indexes')
-        if map_indexes and not isinstance(map_indexes, list):
-            map_indexes = list(map_indexes)
+        if map_indexes:
+            if not isinstance(map_indexes, list):
+                map_indexes = list(map_indexes)
+        else:
+            map_indexes = [-1]
+        task_id_map_indexes = [(task_id, map_index) for map_index in map_indexes]
 
         execution_date = request.form.get('execution_date')
         execution_date = timezone.parse(execution_date)
@@ -2053,7 +2058,7 @@ class Airflow(AirflowBaseView):
             start_date,
             end_date,
             origin,
-            map_indexes=map_indexes,
+            task_id_map_index_list=task_id_map_indexes,
             recursive=recursive,
             confirmed=confirmed,
             only_failed=only_failed,
@@ -2072,9 +2077,6 @@ class Airflow(AirflowBaseView):
         dag_id = request.form.get('dag_id')
         dag_run_id = request.form.get('dag_run_id')
         confirmed = request.form.get('confirmed') == "true"
-        map_indexes = request.form.get('map_indexes')
-        if map_indexes and not isinstance(map_indexes, list):
-            map_indexes = list(map_indexes)
 
         dag = current_app.dag_bag.get_dag(dag_id)
         dr = dag.get_dagrun(run_id=dag_run_id)
@@ -2085,7 +2087,6 @@ class Airflow(AirflowBaseView):
             dag,
             start_date,
             end_date,
-            map_indexes=map_indexes,
             origin=None,
             recursive=True,
             confirmed=confirmed,
@@ -2339,8 +2340,11 @@ class Airflow(AirflowBaseView):
         state = args.get('state')
         origin = args.get('origin')
         map_indexes = args.get('map_indexes')
-        if map_indexes and not isinstance(map_indexes, list):
-            map_indexes = list(map_indexes)
+        if map_indexes:
+            if not isinstance(map_indexes, list):
+                map_indexes = list(map_indexes)
+        else:
+            map_indexes = [-1]
 
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
@@ -2376,7 +2380,7 @@ class Airflow(AirflowBaseView):
         from airflow.api.common.mark_tasks import set_state
 
         to_be_altered = set_state(
-            tasks=[task],
+            tasks=[(task, map_index) for map_index in map_indexes],
             map_indexes=map_indexes,
             run_id=dag_run_id,
             upstream=upstream,
@@ -2418,8 +2422,11 @@ class Airflow(AirflowBaseView):
         origin = get_safe_url(args.get('origin'))
         dag_run_id = args.get('dag_run_id')
         map_indexes = args.get('map_indexes')
-        if map_indexes and not isinstance(map_indexes, list):
-            map_indexes = list(map_indexes)
+        if map_indexes:
+            if not isinstance(map_indexes, list):
+                map_indexes = list(map_indexes)
+        else:
+            map_indexes = [-1]
 
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
@@ -2455,8 +2462,11 @@ class Airflow(AirflowBaseView):
         origin = get_safe_url(args.get('origin'))
         dag_run_id = args.get('dag_run_id')
         map_indexes = args.get('map_indexes')
-        if map_indexes and not isinstance(map_indexes, list):
-            map_indexes = list(map_indexes)
+        if map_indexes:
+            if not isinstance(map_indexes, list):
+                map_indexes = list(map_indexes)
+        else:
+            map_indexes = [-1]
 
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
diff --git a/tests/api/common/test_mark_tasks.py b/tests/api/common/test_mark_tasks.py
index dedb624b1f..4c1d3c604b 100644
--- a/tests/api/common/test_mark_tasks.py
+++ b/tests/api/common/test_mark_tasks.py
@@ -439,12 +439,12 @@ class TestMarkTasks:
     def test_mark_mapped_task_instance_state(self):
         # set mapped task instance to success
         snapshot = TestMarkTasks.snapshot_state(self.dag4, self.execution_dates)
-        tasks = [self.dag4.get_task("consumer_literal")]
+        task = self.dag4.get_task("consumer_literal")
+        tasks = [(task, 0), (task, 1)]
         map_indexes = [0, 1]
         dr = DagRun.find(dag_id=self.dag4.dag_id, execution_date=self.execution_dates[0])[0]
         altered = set_state(
             tasks=tasks,
-            map_indexes=map_indexes,
             run_id=dr.run_id,
             upstream=False,
             downstream=False,
@@ -456,7 +456,7 @@ class TestMarkTasks:
         assert len(altered) == 2
         self.verify_state(
             self.dag4,
-            [task.task_id for task in tasks],
+            [task.task_id for task, _ in tasks],
             [self.execution_dates[0]],
             State.SUCCESS,
             snapshot,
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 41219ef87b..ed2119a490 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -1456,7 +1456,7 @@ class TestDag(unittest.TestCase):
         session.flush()
 
         dag.clear(
-            map_indexes=[0],
+            task_ids_and_map_indexes=[(task_id, 0)],
             start_date=DEFAULT_DATE,
             end_date=DEFAULT_DATE + datetime.timedelta(days=1),
             dag_run_state=dag_run_state,
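
Note on the (task_id, map_index) convention used in this commit: the diff
treats every task instance as a (task_id, map_index) pair and uses -1 as the
map_index of a task that is not mapped. The sketch below illustrates that
normalisation, and the exclusion step that follows it, in isolation. It is
not the Airflow implementation: FakeTask, normalize and drop_excluded are
hypothetical names introduced only for this example, and no database is
involved.

```python
from dataclasses import dataclass
from typing import Iterable, Iterator, List, Tuple, Union

UNMAPPED = -1  # map_index the diff above assigns to tasks that are not mapped


@dataclass(frozen=True)
class FakeTask:
    """Hypothetical stand-in for an operator; only task_id matters here."""

    task_id: str


TaskOrPair = Union[FakeTask, Tuple[FakeTask, int]]
Key = Tuple[str, int]


def normalize(tasks: Iterable[TaskOrPair]) -> Iterator[Key]:
    """Yield (task_id, map_index) for bare tasks and (task, map_index) tuples alike."""
    for item in tasks:
        if isinstance(item, tuple):
            task, map_index = item
        else:
            task, map_index = item, UNMAPPED
        yield task.task_id, map_index


def drop_excluded(keys: Iterable[Key], excluded: Iterable[Key]) -> List[Key]:
    """Keep only the keys whose (task_id, map_index) pair is not excluded."""
    excluded_set = set(excluded)
    return [key for key in keys if key not in excluded_set]


producer = FakeTask("producer")
consumer = FakeTask("consumer_literal")

pairs = list(normalize([producer, (consumer, 0), (consumer, 1)]))
kept = drop_excluded(pairs, [("consumer_literal", 1)])
print(pairs)
print(kept)
```

Because every entry has the same shape after normalisation, a single
composite comparison on (task_id, map_index) can replace the separate
"list of strings" and "list of tuples" branches that the removed lines
handled.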


[airflow] 02/19: fixup! fixup! Allow marking/clearing mapped taskinstances from the UI

Posted by bb...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 334e39923eb8e96fa6e406c427ea604a07c974b4
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Tue Apr 12 17:02:07 2022 +0100

    fixup! fixup! Allow marking/clearing mapped taskinstances from the UI
---
 airflow/www/views.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/airflow/www/views.py b/airflow/www/views.py
index 7be1289144..5e7f7a01e6 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -2310,7 +2310,7 @@ class Airflow(AirflowBaseView):
 
         altered = dag.set_task_instance_state(
             task_id=task_id,
-            map_index=map_indexes,
+            map_indexes=map_indexes,
             run_id=dag_run_id,
             state=state,
             upstream=upstream,
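
Note on the map_indexes value passed above: the views.py hunks in this thread
read it with request.form.get('map_indexes') and then call list() on the
result. A form value read that way is a single string, and list() on a string
splits it into characters, so an index such as "12" would become ['1', '2'].
The sketch below shows one way the raw values could be turned into integers
with the same [-1] fallback. parse_map_indexes and pair_with_task are
hypothetical helpers, not part of Airflow; in a Flask view the raw values
would typically come from request.form.getlist(...), which returns every
value submitted for a field.

```python
from typing import Iterable, List, Tuple


def parse_map_indexes(raw_values: Iterable[str]) -> List[int]:
    """Convert raw form values to ints; fall back to [-1] when none are given."""
    indexes = [int(value) for value in raw_values if value.strip()]
    return indexes or [-1]


def pair_with_task(task_id: str, map_indexes: Iterable[int]) -> List[Tuple[str, int]]:
    """Build the (task_id, map_index) pairs used throughout this branch."""
    return [(task_id, map_index) for map_index in map_indexes]


# list() on a string yields characters, which is why it is avoided here.
print(list("12"))
print(parse_map_indexes(["0", "12"]))
print(pair_with_task("consumer_literal", parse_map_indexes([])))
```

Empty or whitespace-only values are skipped, so a hidden input that was
submitted without a value still resolves to the unmapped case.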


[airflow] 01/19: fixup! Allow marking/clearing mapped taskinstances from the UI

Posted by bb...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 5620df94692a8802202789b21a05104999f8494c
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Tue Apr 12 16:56:11 2022 +0100

    fixup! Allow marking/clearing mapped taskinstances from the UI
---
 airflow/www/views.py | 168 ++++++++++++++++++++++++---------------------------
 1 file changed, 80 insertions(+), 88 deletions(-)

diff --git a/airflow/www/views.py b/airflow/www/views.py
index ae0186e493..7be1289144 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -95,7 +95,6 @@ from airflow.api.common.mark_tasks import (
     set_dag_run_state_to_failed,
     set_dag_run_state_to_queued,
     set_dag_run_state_to_success,
-    set_state,
 )
 from airflow.compat.functools import cached_property
 from airflow.configuration import AIRFLOW_CONFIG, conf
@@ -108,7 +107,6 @@ from airflow.models import DAG, Connection, DagModel, DagTag, Log, SlaMiss, Task
 from airflow.models.abstractoperator import AbstractOperator
 from airflow.models.dagcode import DagCode
 from airflow.models.dagrun import DagRun, DagRunType
-from airflow.models.operator import Operator
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskinstance import TaskInstance
 from airflow.providers_manager import ProvidersManager
@@ -1960,11 +1958,11 @@ class Airflow(AirflowBaseView):
 
     def _clear_dag_tis(
         self,
-        dag: DAG,
+        dag,
         start_date,
         end_date,
         origin,
-        task_ids=None,
+        map_indexes=None,
         recursive=False,
         confirmed=False,
         only_failed=False,
@@ -1973,7 +1971,7 @@ class Airflow(AirflowBaseView):
             count = dag.clear(
                 start_date=start_date,
                 end_date=end_date,
-                task_ids=task_ids,
+                map_indexes=map_indexes,
                 include_subdags=recursive,
                 include_parentdag=recursive,
                 only_failed=only_failed,
@@ -1986,7 +1984,7 @@ class Airflow(AirflowBaseView):
             tis = dag.clear(
                 start_date=start_date,
                 end_date=end_date,
-                task_ids=task_ids,
+                map_indexes=map_indexes,
                 include_subdags=recursive,
                 include_parentdag=recursive,
                 only_failed=only_failed,
@@ -1995,19 +1993,24 @@ class Airflow(AirflowBaseView):
         except AirflowException as ex:
             return redirect_or_json(origin, msg=str(ex), status="error")
 
-        assert isinstance(tis, collections.abc.Iterable)
-        details = [str(t) for t in tis]
-
-        if not details:
-            return redirect_or_json(origin, "No task instances to clear", status="error")
+        if not tis:
+            msg = "No task instances to clear"
+            return redirect_or_json(origin, msg, status="error")
         elif request.headers.get('Accept') == 'application/json':
+            details = [str(t) for t in tis]
+
             return htmlsafe_json_dumps(details, separators=(',', ':'))
-        return self.render_template(
-            'airflow/confirm.html',
-            endpoint=None,
-            message="Task instances you are about to clear:",
-            details="\n".join(details),
-        )
+        else:
+            details = "\n".join(str(t) for t in tis)
+
+            response = self.render_template(
+                'airflow/confirm.html',
+                endpoint=None,
+                message="Task instances you are about to clear:",
+                details=details,
+            )
+
+        return response
 
     @expose('/clear', methods=['POST'])
     @auth.has_access(
@@ -2023,11 +2026,9 @@ class Airflow(AirflowBaseView):
         task_id = request.form.get('task_id')
         origin = get_safe_url(request.form.get('origin'))
         dag = current_app.dag_bag.get_dag(dag_id)
-
-        if 'map_index' not in request.form:
-            map_indexes: Optional[List[int]] = None
-        else:
-            map_indexes = request.form.getlist('map_index', type=int)
+        map_indexes = request.form.get('map_indexes')
+        if map_indexes and not isinstance(map_indexes, list):
+            map_indexes = list(map_indexes)
 
         execution_date = request.form.get('execution_date')
         execution_date = timezone.parse(execution_date)
@@ -2047,17 +2048,12 @@ class Airflow(AirflowBaseView):
         end_date = execution_date if not future else None
         start_date = execution_date if not past else None
 
-        if map_indexes is None:
-            task_ids: Union[List[str], List[Tuple[str, int]]] = [task_id]
-        else:
-            task_ids = [(task_id, map_index) for map_index in map_indexes]
-
         return self._clear_dag_tis(
             dag,
             start_date,
             end_date,
             origin,
-            task_ids=task_ids,
+            map_indexes=map_indexes,
             recursive=recursive,
             confirmed=confirmed,
             only_failed=only_failed,
@@ -2076,6 +2072,9 @@ class Airflow(AirflowBaseView):
         dag_id = request.form.get('dag_id')
         dag_run_id = request.form.get('dag_run_id')
         confirmed = request.form.get('confirmed') == "true"
+        map_indexes = request.form.get('map_indexes')
+        if map_indexes and not isinstance(map_indexes, list):
+            map_indexes = list(map_indexes)
 
         dag = current_app.dag_bag.get_dag(dag_id)
         dr = dag.get_dagrun(run_id=dag_run_id)
@@ -2086,6 +2085,7 @@ class Airflow(AirflowBaseView):
             dag,
             start_date,
             end_date,
+            map_indexes=map_indexes,
             origin=None,
             recursive=True,
             confirmed=confirmed,
@@ -2290,28 +2290,28 @@ class Airflow(AirflowBaseView):
 
     def _mark_task_instance_state(
         self,
-        *,
-        dag_id: str,
-        run_id: str,
-        task_id: str,
-        map_indexes: Optional[List[int]],
-        origin: str,
-        upstream: bool,
-        downstream: bool,
-        future: bool,
-        past: bool,
-        state: TaskInstanceState,
+        dag_id,
+        task_id,
+        map_indexes,
+        origin,
+        dag_run_id,
+        upstream,
+        downstream,
+        future,
+        past,
+        state,
     ):
-        dag: DAG = current_app.dag_bag.get_dag(dag_id)
+        dag = current_app.dag_bag.get_dag(dag_id)
+        latest_execution_date = dag.get_latest_execution_date()
 
-        if not run_id:
-            flash(f"Cannot mark tasks as {state}, seem that DAG {dag_id} has never run", "error")
+        if not latest_execution_date:
+            flash(f"Cannot mark tasks as {state}, seem that dag {dag_id} has never run", "error")
             return redirect(origin)
 
         altered = dag.set_task_instance_state(
             task_id=task_id,
-            map_indexes=map_indexes,
-            run_id=run_id,
+            map_index=map_indexes,
+            run_id=dag_run_id,
             state=state,
             upstream=upstream,
             downstream=downstream,
@@ -2338,11 +2338,9 @@ class Airflow(AirflowBaseView):
         dag_run_id = args.get('dag_run_id')
         state = args.get('state')
         origin = args.get('origin')
-
-        if 'map_index' not in args:
-            map_indexes: Optional[List[int]] = None
-        else:
-            map_indexes = args.getlist('map_index', type=int)
+        map_indexes = args.get('map_indexes')
+        if map_indexes and not isinstance(map_indexes, list):
+            map_indexes = list(map_indexes)
 
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
@@ -2375,13 +2373,11 @@ class Airflow(AirflowBaseView):
             msg = f"Cannot mark tasks as {state}, seem that dag {dag_id} has never run"
             return redirect_or_json(origin, msg, status='error')
 
-        if map_indexes is None:
-            tasks: Union[List[Operator], List[Tuple[Operator, int]]] = [task]
-        else:
-            tasks = [(task, map_index) for map_index in map_indexes]
+        from airflow.api.common.mark_tasks import set_state
 
         to_be_altered = set_state(
-            tasks=tasks,
+            tasks=[task],
+            map_indexes=map_indexes,
             run_id=dag_run_id,
             upstream=upstream,
             downstream=downstream,
@@ -2419,30 +2415,28 @@ class Airflow(AirflowBaseView):
         args = request.form
         dag_id = args.get('dag_id')
         task_id = args.get('task_id')
-        run_id = args.get('dag_run_id')
-
-        if 'map_index' not in args:
-            map_indexes: Optional[List[int]] = None
-        else:
-            map_indexes = args.getlist('map_index', type=int)
-
         origin = get_safe_url(args.get('origin'))
+        dag_run_id = args.get('dag_run_id')
+        map_indexes = args.get('map_indexes')
+        if map_indexes and not isinstance(map_indexes, list):
+            map_indexes = list(map_indexes)
+
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
         future = to_boolean(args.get('future'))
         past = to_boolean(args.get('past'))
 
         return self._mark_task_instance_state(
-            dag_id=dag_id,
-            run_id=run_id,
-            task_id=task_id,
-            map_indexes=map_indexes,
-            origin=origin,
-            upstream=upstream,
-            downstream=downstream,
-            future=future,
-            past=past,
-            state=TaskInstanceState.FAILED,
+            dag_id,
+            task_id,
+            map_indexes,
+            origin,
+            dag_run_id,
+            upstream,
+            downstream,
+            future,
+            past,
+            State.FAILED,
         )
 
     @expose('/success', methods=['POST'])
@@ -2458,30 +2452,28 @@ class Airflow(AirflowBaseView):
         args = request.form
         dag_id = args.get('dag_id')
         task_id = args.get('task_id')
-        run_id = args.get('dag_run_id')
-
-        if 'map_index' not in args:
-            map_indexes: Optional[List[int]] = None
-        else:
-            map_indexes = args.getlist('map_index', type=int)
-
         origin = get_safe_url(args.get('origin'))
+        dag_run_id = args.get('dag_run_id')
+        map_indexes = args.get('map_indexes')
+        if map_indexes and not isinstance(map_indexes, list):
+            map_indexes = list(map_indexes)
+
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
         future = to_boolean(args.get('future'))
         past = to_boolean(args.get('past'))
 
         return self._mark_task_instance_state(
-            dag_id=dag_id,
-            run_id=run_id,
-            task_id=task_id,
-            map_indexes=map_indexes,
-            origin=origin,
-            upstream=upstream,
-            downstream=downstream,
-            future=future,
-            past=past,
-            state=TaskInstanceState.SUCCESS,
+            dag_id,
+            task_id,
+            map_indexes,
+            origin,
+            dag_run_id,
+            upstream,
+            downstream,
+            future,
+            past,
+            State.SUCCESS,
         )
 
     @expose('/dags/<string:dag_id>')

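The hunks above replace `getlist('map_index', type=int)` with `get('map_indexes')` followed by `list(...)`. On a Werkzeug form, `get` returns one string rather than every posted value, and calling `list()` on a string splits it into characters. A minimal sketch of the difference, using only the standard library instead of Flask; `parse_map_indexes` is a hypothetical helper and not part of Airflow:

```python
from typing import List, Optional
from urllib.parse import parse_qs


def parse_map_indexes(body: str, key: str = "map_index") -> Optional[List[int]]:
    """Return every integer value posted under ``key``, or None if the key is absent.

    Hypothetical helper for illustration only; not part of Airflow.
    """
    fields = parse_qs(body)  # maps each key to the list of all its values
    if key not in fields:
        return None
    return [int(value) for value in fields[key]]


# Calling list() on a single string value splits it into characters.
single_value = "12"
print(list(single_value))

# A repeated key keeps every value, converted to int.
print(parse_map_indexes("task_id=t1&map_index=0&map_index=12"))
print(parse_map_indexes("task_id=t1"))
```

Returning None when the key is absent keeps "no index supplied" distinct from an explicit index of 0.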

[airflow] 03/19: fixup! fixup! fixup! Allow marking/clearing mapped taskinstances from the UI

Posted by bb...@apache.org.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 83413da968849574ed78d78918d0d08ed82563b9
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Tue Apr 12 17:07:05 2022 +0100

    fixup! fixup! fixup! Allow marking/clearing mapped taskinstances from the UI
---
 airflow/models/dag.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 83860ba591..e9c33acb72 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -1703,6 +1703,7 @@ class DAG(LoggingMixin):
         subdag.clear(
             start_date=start_date,
             end_date=end_date,
+            map_indexes=map_indexes,
             include_subdags=True,
             include_parentdag=True,
             only_failed=True,


[airflow] 07/19: fixup! fixup! fixup! fixup! fixup! fixup! Allow marking/clearing mapped taskinstances from the UI

Posted by bb...@apache.org.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 5bbcb9f17f92e763a35a08c6522a2031506dd91c
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Wed Apr 13 21:42:09 2022 +0100

    fixup! fixup! fixup! fixup! fixup! fixup! Allow marking/clearing mapped taskinstances from the UI
---
 airflow/www/views.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/airflow/www/views.py b/airflow/www/views.py
index 7b0f81af80..2a3ad1e913 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -2293,7 +2293,6 @@ class Airflow(AirflowBaseView):
         self,
         dag_id,
         task_id,
-        map_indexes,
         origin,
         dag_run_id,
         upstream,
@@ -2301,6 +2300,7 @@ class Airflow(AirflowBaseView):
         future,
         past,
         state,
+        map_indexes=None,
     ):
         dag = current_app.dag_bag.get_dag(dag_id)
         latest_execution_date = dag.get_latest_execution_date()
@@ -2381,7 +2381,6 @@ class Airflow(AirflowBaseView):
 
         to_be_altered = set_state(
             tasks=[(task, map_index) for map_index in map_indexes],
-            map_indexes=map_indexes,
             run_id=dag_run_id,
             upstream=upstream,
             downstream=downstream,
@@ -2436,7 +2435,6 @@ class Airflow(AirflowBaseView):
         return self._mark_task_instance_state(
             dag_id,
             task_id,
-            map_indexes,
             origin,
             dag_run_id,
             upstream,
@@ -2444,6 +2442,7 @@ class Airflow(AirflowBaseView):
             future,
             past,
             State.FAILED,
+            map_indexes=map_indexes,
         )
 
     @expose('/success', methods=['POST'])
@@ -2476,7 +2475,6 @@ class Airflow(AirflowBaseView):
         return self._mark_task_instance_state(
             dag_id,
             task_id,
-            map_indexes,
             origin,
             dag_run_id,
             upstream,
@@ -2484,6 +2482,7 @@ class Airflow(AirflowBaseView):
             future,
             past,
             State.SUCCESS,
+            map_indexes=map_indexes,
         )
 
     @expose('/dags/<string:dag_id>')

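The commit above makes `map_indexes` an optional keyword argument, and a context line pairs one task with each index for `set_state`. A rough sketch of that pairing, assuming None means "no index supplied"; `build_task_items` is a hypothetical name, and plain strings stand in for operator objects:

```python
from typing import Iterable, List, Optional, Tuple, Union

TaskItem = Union[str, Tuple[str, int]]


def build_task_items(task_id: str, map_indexes: Optional[Iterable[int]] = None) -> List[TaskItem]:
    """Pair ``task_id`` with each map index, or return it bare when none are given.

    Hypothetical helper for illustration; plain strings stand in for operators.
    """
    if map_indexes is None:  # explicit None check, so an index of 0 is not dropped
        return [task_id]
    return [(task_id, map_index) for map_index in map_indexes]


print(build_task_items("extract"))
print(build_task_items("extract", map_indexes=[0, 2]))
```

The explicit `is None` test matters because 0 is a valid map index but is falsy in Python.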

[airflow] 09/19: fixup! Apply suggestions from code review

Posted by bb...@apache.org.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 4e81a780308a52e16f7fdd2ca8de8e2053636b7c
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Fri Apr 15 12:26:53 2022 +0100

    fixup! Apply suggestions from code review
---
 airflow/api/common/mark_tasks.py | 27 +++++++++-----
 airflow/models/dag.py            | 79 +++++++++++++++++++---------------------
 airflow/www/views.py             | 52 ++++++++------------------
 tests/models/test_dag.py         |  2 +-
 4 files changed, 73 insertions(+), 87 deletions(-)

diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index a9e4f4812e..594423305c 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -119,10 +119,7 @@ def set_state(
     if execution_date and not timezone.is_localized(execution_date):
         raise ValueError(f"Received non-localized date {execution_date}")
 
-    task_dags = {
-        task[0].dag if isinstance(task, tuple) else task.dag
-        for task in tasks
-    }
+    task_dags = {task[0].dag if isinstance(task, tuple) else task.dag for task in tasks}
     if len(task_dags) > 1:
         raise ValueError(f"Received tasks from multiple DAGs: {task_dags}")
     dag = next(iter(task_dags))
@@ -137,6 +134,12 @@ def set_state(
     dag_run_ids = get_run_ids(dag, run_id, future, past)
     task_id_map_index_list = list(find_task_relatives(tasks, downstream, upstream))
     task_ids = [task_id for task_id, _ in task_id_map_index_list]
+    # check if task_id_map_index_list contains map_index of None
+    # if it contains None, there was no map_index supplied for the task
+    for _, index in task_id_map_index_list:
+        if index is None:
+            task_id_map_index_list = [task_id for task_id, _ in task_id_map_index_list]
+            break
 
     confirmed_infos = list(_iter_existing_dag_run_infos(dag, dag_run_ids))
     confirmed_dates = [info.logical_date for info in confirmed_infos]
@@ -183,20 +186,26 @@ def get_all_dag_task_query(
     dag: DAG,
     session: SASession,
     state: TaskInstanceState,
-    task_id_map_index_list: List[Tuple[str, int]],
+    task_ids: Union[List[str], List[Tuple[str, int]]],
     confirmed_dates: Iterable[datetime],
 ):
     """Get all tasks of the main dag that will be affected by a state change"""
+    is_string_list = isinstance(task_ids[0], str)
     qry_dag = (
         session.query(TaskInstance)
         .join(TaskInstance.dag_run)
         .filter(
             TaskInstance.dag_id == dag.dag_id,
             DagRun.execution_date.in_(confirmed_dates),
-            tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_id_map_index_list),
         )
-        .filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
-        .options(contains_eager(TaskInstance.dag_run))
+    )
+
+    if is_string_list:
+        qry_dag = qry_dag.filter(TaskInstance.task_id.in_(task_ids))
+    else:
+        qry_dag = qry_dag.filter(tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_ids))
+    qry_dag = qry_dag.filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state)).options(
+        contains_eager(TaskInstance.dag_run)
     )
     return qry_dag
 
@@ -278,7 +287,7 @@ def find_task_relatives(tasks, downstream, upstream):
         if isinstance(item, tuple):
             task, map_index = item
         else:
-            task, map_index = item, -1
+            task, map_index = item, None
         yield task.task_id, map_index
         if downstream:
             for relative in task.get_flat_relatives(upstream=False):
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 8856a841d3..0694f37550 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -1349,7 +1349,6 @@ class DAG(LoggingMixin):
                 Query,
                 self._get_task_instances(
                     task_ids=None,
-                    task_ids_and_map_indexes=None,
                     start_date=start_date,
                     end_date=end_date,
                     run_id=None,
@@ -1358,7 +1357,6 @@ class DAG(LoggingMixin):
                     include_parentdag=False,
                     include_dependent_dags=False,
                     exclude_task_ids=cast(List[str], []),
-                    exclude_task_ids_and_map_indexes=None,
                     session=session,
                 ),
             )
@@ -1370,7 +1368,7 @@ class DAG(LoggingMixin):
     def _get_task_instances(
         self,
         *,
-        task_ids,
+        task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
         task_ids_and_map_indexes,
         start_date: Optional[datetime],
         end_date: Optional[datetime],
@@ -1379,8 +1377,7 @@ class DAG(LoggingMixin):
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
-        exclude_task_ids: Collection[str],
-        exclude_task_ids_and_map_indexes,
+        exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
         session: Session,
         dag_bag: Optional["DagBag"] = ...,
     ) -> Iterable[TaskInstance]:
@@ -1390,8 +1387,7 @@ class DAG(LoggingMixin):
     def _get_task_instances(
         self,
         *,
-        task_ids,
-        task_ids_and_map_indexes,
+        task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
         as_pk_tuple: Literal[True],
         start_date: Optional[datetime],
         end_date: Optional[datetime],
@@ -1400,8 +1396,7 @@ class DAG(LoggingMixin):
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
-        exclude_task_ids: Collection[str],
-        exclude_task_ids_and_map_indexes,
+        exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
         session: Session,
         dag_bag: Optional["DagBag"] = ...,
         recursion_depth: int = ...,
@@ -1413,8 +1408,7 @@ class DAG(LoggingMixin):
     def _get_task_instances(
         self,
         *,
-        task_ids,
-        task_ids_and_map_indexes,
+        task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
         as_pk_tuple: Literal[True, None] = None,
         start_date: Optional[datetime],
         end_date: Optional[datetime],
@@ -1423,8 +1417,7 @@ class DAG(LoggingMixin):
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
-        exclude_task_ids: Collection[str],
-        exclude_task_ids_and_map_indexes,
+        exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
         session: Session,
         dag_bag: Optional["DagBag"] = None,
         recursion_depth: int = 0,
@@ -1448,10 +1441,17 @@ class DAG(LoggingMixin):
             tis = session.query(TaskInstance)
         tis = tis.join(TaskInstance.dag_run)
 
-        if task_ids is not None:  # task not mapped
-            task_ids_and_map_indexes = [(task_id, -1) for task_id in task_ids]
-        if exclude_task_ids and len(exclude_task_ids) > 0:  # task not mapped
-            exclude_task_ids_and_map_indexes = [(task_id, -1) for task_id in task_ids]
+        task_ids_and_map_indexes = None
+        if task_ids is not None:
+            task_ids_and_map_indexes = [item for item in task_ids if isinstance(item, tuple)]
+        if task_ids_and_map_indexes:
+            task_ids = None  # nullify since we have indexes
+
+        exclude_task_ids_and_map_indexes = None
+        if exclude_task_ids is not None:
+            exclude_task_ids_and_map_indexes = [item for item in exclude_task_ids if isinstance(item, tuple)]
+        if exclude_task_ids_and_map_indexes:
+            exclude_task_ids = None
 
         if include_subdags:
             # Crafting the right filter for dag_id and task_ids combo
@@ -1467,6 +1467,8 @@ class DAG(LoggingMixin):
             tis = tis.filter(TaskInstance.run_id == run_id)
         if start_date:
             tis = tis.filter(DagRun.execution_date >= start_date)
+        if task_ids:
+            tis = tis.filter(TaskInstance.task_id.in_(task_ids))
         if task_ids_and_map_indexes:
             tis = tis.filter(
                 tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_ids_and_map_indexes)
@@ -1509,7 +1511,6 @@ class DAG(LoggingMixin):
             result.update(
                 p_dag._get_task_instances(
                     task_ids=task_ids,
-                    task_ids_and_map_indexes=task_ids_and_map_indexes,
                     start_date=start_date,
                     end_date=end_date,
                     run_id=None,
@@ -1519,7 +1520,6 @@ class DAG(LoggingMixin):
                     include_dependent_dags=include_dependent_dags,
                     as_pk_tuple=True,
                     exclude_task_ids=exclude_task_ids,
-                    exclude_task_ids_and_map_indexes=exclude_task_ids_and_map_indexes,
                     session=session,
                     dag_bag=dag_bag,
                     recursion_depth=recursion_depth,
@@ -1588,7 +1588,6 @@ class DAG(LoggingMixin):
                     result.update(
                         downstream._get_task_instances(
                             task_ids=None,
-                            task_ids_and_map_indexes=None,
                             run_id=tii.run_id,
                             start_date=None,
                             end_date=None,
@@ -1598,7 +1597,6 @@ class DAG(LoggingMixin):
                             include_parentdag=False,
                             as_pk_tuple=True,
                             exclude_task_ids=exclude_task_ids,
-                            exclude_task_ids_and_map_indexes=exclude_task_ids_and_map_indexes,
                             dag_bag=dag_bag,
                             session=session,
                             recursion_depth=recursion_depth + 1,
@@ -1614,7 +1612,15 @@ class DAG(LoggingMixin):
             else:
                 result.update(ti.key for ti in tis.all())
 
-            if exclude_task_ids_and_map_indexes:
+            if exclude_task_ids is not None:
+                result = set(
+                    filter(
+                        lambda key: key.task_id not in exclude_task_ids,
+                        result,
+                    )
+                )
+
+            if exclude_task_ids_and_map_indexes is not None:
                 result = set(
                     filter(
                         lambda key: (key.task_id, key.map_index) not in exclude_task_ids_and_map_indexes,
@@ -1639,7 +1645,7 @@ class DAG(LoggingMixin):
         self,
         *,
         task_id: str,
-        map_indexes: Optional[Iterable[int]] = None,
+        map_index: Optional[int] = None,
         execution_date: Optional[datetime] = None,
         run_id: Optional[str] = None,
         state: TaskInstanceState,
@@ -1655,7 +1661,8 @@ class DAG(LoggingMixin):
         in failed or upstream_failed state.
 
         :param task_id: Task ID of the TaskInstance
-        :param map_indexes: Task instance map_index to set the state of
+        :param map_index: The TaskInstance map_index, if None, would set state for all mapped
+            TaskInstances of the task
         :param execution_date: Execution date of the TaskInstance
         :param run_id: The run_id of the TaskInstance
         :param state: State to set the TaskInstance to
@@ -1680,10 +1687,8 @@ class DAG(LoggingMixin):
 
         task = self.get_task(task_id)
         task.dag = self
-        if not map_indexes:
-            map_indexes = [-1]
-        task_map_indexes = [(task, map_index) for map_index in map_indexes]
-        task_id_map_indexes = [(task_id, map_index) for map_index in map_indexes]
+        task_map_indexes = [(task, map_index)] if map_index else [task]
+        task_id_map_indexes = {(task_id, map_index)} if map_index else {task_id}
 
         altered = set_state(
             tasks=task_map_indexes,
@@ -1721,7 +1726,7 @@ class DAG(LoggingMixin):
             only_failed=True,
             session=session,
             # Exclude the task itself from being cleared
-            exclude_task_ids_and_map_indexes=task_id_map_indexes,
+            exclude_task_ids=task_id_map_indexes,
         )
 
         return altered
@@ -1779,8 +1784,7 @@ class DAG(LoggingMixin):
     @provide_session
     def clear(
         self,
-        task_ids=None,
-        task_ids_and_map_indexes: Optional[Iterable[Tuple[str, int]]] = None,
+        task_ids: Union[Iterable[str], Iterable[Tuple[str, int]], None] = None,
         start_date: Optional[datetime] = None,
         end_date: Optional[datetime] = None,
         only_failed: bool = False,
@@ -1795,15 +1799,13 @@ class DAG(LoggingMixin):
         recursion_depth: int = 0,
         max_recursion_depth: Optional[int] = None,
         dag_bag: Optional["DagBag"] = None,
-        exclude_task_ids: FrozenSet[str] = frozenset(),
-        exclude_task_ids_and_map_indexes: FrozenSet[Tuple[str, int]] = frozenset({}),
+        exclude_task_ids: Union[FrozenSet[str], FrozenSet[Tuple[str, int]], None] = frozenset(),
     ) -> Union[int, Iterable[TaskInstance]]:
         """
         Clears a set of task instances associated with the current dag for
         a specified date range.
 
-        :param task_ids: List of task ids to clear
-        :param task_ids_and_map_indexes: List of tuple of task_id, map_index to clear
+        :param task_ids: List of task ids or (``task_id``, ``map_index``) tuples to clear
         :param start_date: The minimum execution_date to clear
         :param end_date: The maximum execution_date to clear
         :param only_failed: Only clear failed tasks
@@ -1817,8 +1819,7 @@ class DAG(LoggingMixin):
         :param dry_run: Find the tasks to clear but don't clear them.
         :param session: The sqlalchemy session to use
         :param dag_bag: The DagBag used to find the dags subdags (Optional)
-        :param exclude_task_ids: A set of ``task_id`` that should not be cleared
-        :param exclude_task_ids_and_map_indexes: A set of ``task_id``,``map_index``
+        :param exclude_task_ids: A set of ``task_id`` or (``task_id``, ``map_index``)
             tuples that should not be cleared
         """
         if get_tis:
@@ -1848,12 +1849,9 @@ class DAG(LoggingMixin):
         if only_running:
             # Yes, having `+=` doesn't make sense, but this was the existing behaviour
             state += [State.RUNNING]
-        if task_ids:
-            task_ids_and_map_indexes = [(task_id, -1) for task_id in task_ids]
 
         tis = self._get_task_instances(
             task_ids=task_ids,
-            task_ids_and_map_indexes=task_ids_and_map_indexes,
             start_date=start_date,
             end_date=end_date,
             run_id=None,
@@ -1864,7 +1862,6 @@ class DAG(LoggingMixin):
             session=session,
             dag_bag=dag_bag,
             exclude_task_ids=exclude_task_ids,
-            exclude_task_ids_and_map_indexes=exclude_task_ids_and_map_indexes,
         )
 
         if dry_run:
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 2a3ad1e913..437a60cca0 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1962,7 +1962,7 @@ class Airflow(AirflowBaseView):
         start_date,
         end_date,
         origin,
-        task_id_map_index_list=None,
+        task_ids=None,
         recursive=False,
         confirmed=False,
         only_failed=False,
@@ -1971,7 +1971,7 @@ class Airflow(AirflowBaseView):
             count = dag.clear(
                 start_date=start_date,
                 end_date=end_date,
-                task_ids_and_map_indexes=task_id_map_index_list,
+                task_ids=task_ids,
                 include_subdags=recursive,
                 include_parentdag=recursive,
                 only_failed=only_failed,
@@ -1984,7 +1984,7 @@ class Airflow(AirflowBaseView):
             tis = dag.clear(
                 start_date=start_date,
                 end_date=end_date,
-                task_ids_and_map_indexes=task_id_map_index_list,
+                task_ids=task_ids,
                 include_subdags=recursive,
                 include_parentdag=recursive,
                 only_failed=only_failed,
@@ -2026,14 +2026,7 @@ class Airflow(AirflowBaseView):
         task_id = request.form.get('task_id')
         origin = get_safe_url(request.form.get('origin'))
         dag = current_app.dag_bag.get_dag(dag_id)
-
-        map_indexes = request.form.get('map_indexes')
-        if map_indexes:
-            if not isinstance(map_indexes, list):
-                map_indexes = list(map_indexes)
-        else:
-            map_indexes = [-1]
-        task_id_map_indexes = [(task_id, map_index) for map_index in map_indexes]
+        map_index = request.form.get('map_index')
 
         execution_date = request.form.get('execution_date')
         execution_date = timezone.parse(execution_date)
@@ -2052,13 +2045,13 @@ class Airflow(AirflowBaseView):
         )
         end_date = execution_date if not future else None
         start_date = execution_date if not past else None
-
+        task_ids = [(task_id, map_index)] if map_index else [task_id]
         return self._clear_dag_tis(
             dag,
             start_date,
             end_date,
             origin,
-            task_id_map_index_list=task_id_map_indexes,
+            task_ids=task_ids,
             recursive=recursive,
             confirmed=confirmed,
             only_failed=only_failed,
@@ -2300,7 +2293,7 @@ class Airflow(AirflowBaseView):
         future,
         past,
         state,
-        map_indexes=None,
+        map_index=None,
     ):
         dag = current_app.dag_bag.get_dag(dag_id)
         latest_execution_date = dag.get_latest_execution_date()
@@ -2311,7 +2304,7 @@ class Airflow(AirflowBaseView):
 
         altered = dag.set_task_instance_state(
             task_id=task_id,
-            map_indexes=map_indexes,
+            map_index=map_index,
             run_id=dag_run_id,
             state=state,
             upstream=upstream,
@@ -2339,12 +2332,7 @@ class Airflow(AirflowBaseView):
         dag_run_id = args.get('dag_run_id')
         state = args.get('state')
         origin = args.get('origin')
-        map_indexes = args.get('map_indexes')
-        if map_indexes:
-            if not isinstance(map_indexes, list):
-                map_indexes = list(map_indexes)
-        else:
-            map_indexes = [-1]
+        map_index = args.get('map_index')
 
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
@@ -2379,8 +2367,10 @@ class Airflow(AirflowBaseView):
 
         from airflow.api.common.mark_tasks import set_state
 
+        tasks = [(task, map_index)] if map_index else [task]
+
         to_be_altered = set_state(
-            tasks=[(task, map_index) for map_index in map_indexes],
+            tasks=tasks,
             run_id=dag_run_id,
             upstream=upstream,
             downstream=downstream,
@@ -2420,12 +2410,7 @@ class Airflow(AirflowBaseView):
         task_id = args.get('task_id')
         origin = get_safe_url(args.get('origin'))
         dag_run_id = args.get('dag_run_id')
-        map_indexes = args.get('map_indexes')
-        if map_indexes:
-            if not isinstance(map_indexes, list):
-                map_indexes = list(map_indexes)
-        else:
-            map_indexes = [-1]
+        map_index = args.get('map_index')
 
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
@@ -2442,7 +2427,7 @@ class Airflow(AirflowBaseView):
             future,
             past,
             State.FAILED,
-            map_indexes=map_indexes,
+            map_index=map_index,
         )
 
     @expose('/success', methods=['POST'])
@@ -2460,12 +2445,7 @@ class Airflow(AirflowBaseView):
         task_id = args.get('task_id')
         origin = get_safe_url(args.get('origin'))
         dag_run_id = args.get('dag_run_id')
-        map_indexes = args.get('map_indexes')
-        if map_indexes:
-            if not isinstance(map_indexes, list):
-                map_indexes = list(map_indexes)
-        else:
-            map_indexes = [-1]
+        map_index = args.get('map_index')
 
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
@@ -2482,7 +2462,7 @@ class Airflow(AirflowBaseView):
             future,
             past,
             State.SUCCESS,
-            map_indexes=map_indexes,
+            map_index=map_index,
         )
 
     @expose('/dags/<string:dag_id>')
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index ed2119a490..6cd8ea660f 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -1456,7 +1456,7 @@ class TestDag(unittest.TestCase):
         session.flush()
 
         dag.clear(
-            task_ids_and_map_indexes=[(task_id, 0)],
+            task_ids=[(task_id, 0)],
             start_date=DEFAULT_DATE,
             end_date=DEFAULT_DATE + datetime.timedelta(days=1),
             dag_run_state=dag_run_state,
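
The view changes above come down to one selection rule: with no map index the whole task is targeted, and with a map index only that mapped instance is. A minimal sketch of the rule follows. `build_task_ids` is a hypothetical helper, not part of Airflow, and the conversion to `int` is an assumption made here because form values arrive as strings; the diff passes the raw value through.

```python
from typing import List, Optional, Tuple, Union

TaskSelector = Union[List[str], List[Tuple[str, int]]]


def build_task_ids(task_id: str, map_index: Optional[str]) -> TaskSelector:
    """Hypothetical helper mirroring the selection rule in the diff.

    A missing or empty form value selects the whole task; anything else
    selects a single (task_id, map_index) pair. The int() conversion is
    an assumption of this sketch.
    """
    if map_index is None or map_index == "":
        return [task_id]
    return [(task_id, int(map_index))]


whole_task = build_task_ids("extract", None)
one_instance = build_task_ids("extract", "0")
```

Because the form value is a string, `"0"` is truthy and still selects the first mapped instance; the same check against an integer `0` would not.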


[airflow] 15/19: Allow bulk mapped task actions

Posted by bb...@apache.org.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit b77a9203353077c154a9c9fc8d3c9b099726adb9
Author: Brent Bovenzi <br...@gmail.com>
AuthorDate: Sat Apr 9 14:50:23 2022 -0400

    Allow bulk mapped task actions
---
 airflow/www/static/js/tree/Table.jsx               | 44 ++++++++++++++++++++--
 .../content/taskInstance/MappedInstances.jsx       |  3 +-
 .../js/tree/details/content/taskInstance/index.jsx | 25 +++++++++++-
 .../content/taskInstance/taskActions/Clear.jsx     |  2 +
 .../taskInstance/taskActions/MarkFailed.jsx        |  3 +-
 .../taskInstance/taskActions/MarkSuccess.jsx       |  4 +-
 .../content/taskInstance/taskActions/Run.jsx       | 22 ++++++++---
 7 files changed, 88 insertions(+), 15 deletions(-)

diff --git a/airflow/www/static/js/tree/Table.jsx b/airflow/www/static/js/tree/Table.jsx
index aef91ce905..152f647ea3 100644
--- a/airflow/www/static/js/tree/Table.jsx
+++ b/airflow/www/static/js/tree/Table.jsx
@@ -21,7 +21,7 @@
  * Custom wrapper of react-table using Chakra UI components
 */
 
-import React, { useEffect } from 'react';
+import React, { useEffect, useRef, forwardRef } from 'react';
 import {
   Flex,
   Table as ChakraTable,
@@ -33,9 +33,10 @@ import {
   IconButton,
   Text,
   useColorModeValue,
+  Checkbox,
 } from '@chakra-ui/react';
 import {
-  useTable, useSortBy, usePagination,
+  useTable, useSortBy, usePagination, useRowSelect,
 } from 'react-table';
 import {
   MdKeyboardArrowLeft, MdKeyboardArrowRight,
@@ -44,8 +45,23 @@ import {
   TiArrowUnsorted, TiArrowSortedDown, TiArrowSortedUp,
 } from 'react-icons/ti';
 
+const IndeterminateCheckbox = forwardRef(
+  ({ indeterminate, ...rest }, ref) => {
+    const defaultRef = useRef();
+    const resolvedRef = ref || defaultRef;
+
+    useEffect(() => {
+      resolvedRef.current.indeterminate = indeterminate;
+    }, [resolvedRef, indeterminate]);
+
+    return (
+      <Checkbox ref={resolvedRef} {...rest} />
+    );
+  },
+);
+
 const Table = ({
-  data, columns, manualPagination, pageSize = 25, setSortBy, isLoading = false,
+  data, columns, manualPagination, pageSize = 25, setSortBy, isLoading = false, selectRows,
 }) => {
   const { totalEntries, offset, setOffset } = manualPagination || {};
   const oddColor = useColorModeValue('gray.50', 'gray.900');
@@ -66,7 +82,8 @@ const Table = ({
     canNextPage,
     nextPage,
     previousPage,
-    state: { pageIndex, sortBy },
+    selectedFlatRows,
+    state: { pageIndex, sortBy, selectedRowIds },
   } = useTable(
     {
       columns,
@@ -81,6 +98,20 @@ const Table = ({
     },
     useSortBy,
     usePagination,
+    useRowSelect,
+    (hooks) => {
+      hooks.visibleColumns.push((cols) => [
+        {
+          id: 'selection',
+          Cell: ({ row }) => (
+            <div>
+              <IndeterminateCheckbox {...row.getToggleRowSelectedProps()} />
+            </div>
+          ),
+        },
+        ...cols,
+      ]);
+    },
   );
 
   const handleNext = () => {
@@ -97,6 +128,11 @@ const Table = ({
     if (setSortBy) setSortBy(sortBy);
   }, [sortBy, setSortBy]);
 
+  useEffect(() => {
+    if (selectRows) selectRows(selectedFlatRows.map((row) => row.original.mapIndex));
+  // eslint-disable-next-line react-hooks/exhaustive-deps
+  }, [selectedRowIds, selectRows]);
+
   return (
     <>
       <ChakraTable {...getTableProps()}>
diff --git a/airflow/www/static/js/tree/details/content/taskInstance/MappedInstances.jsx b/airflow/www/static/js/tree/details/content/taskInstance/MappedInstances.jsx
index 42bbdca66f..77c0713ab3 100644
--- a/airflow/www/static/js/tree/details/content/taskInstance/MappedInstances.jsx
+++ b/airflow/www/static/js/tree/details/content/taskInstance/MappedInstances.jsx
@@ -46,7 +46,7 @@ const IconLink = (props) => (
 );
 
 const MappedInstances = ({
-  dagId, runId, taskId,
+  dagId, runId, taskId, selectRows,
 }) => {
   const limit = 25;
   const [offset, setOffset] = useState(0);
@@ -147,6 +147,7 @@ const MappedInstances = ({
         pageSize={limit}
         setSortBy={setSortBy}
         isLoading={isLoading}
+        selectRows={selectRows}
       />
     </Box>
   );
diff --git a/airflow/www/static/js/tree/details/content/taskInstance/index.jsx b/airflow/www/static/js/tree/details/content/taskInstance/index.jsx
index d8b71cb128..62ffee156a 100644
--- a/airflow/www/static/js/tree/details/content/taskInstance/index.jsx
+++ b/airflow/www/static/js/tree/details/content/taskInstance/index.jsx
@@ -17,12 +17,14 @@
  * under the License.
  */
 
-import React from 'react';
+import React, { useState } from 'react';
 import {
   Box,
   VStack,
   Divider,
   StackDivider,
+  Text,
+  Flex,
 } from '@chakra-ui/react';
 
 import RunAction from './taskActions/Run';
@@ -54,6 +56,7 @@ const getTask = ({ taskId, runId, task }) => {
 };
 
 const TaskInstance = ({ taskId, runId }) => {
+  const [selectedRows, setSelectedRows] = useState([]);
   const { data: { groups = {}, dagRuns = [] } } = useTreeData();
   const group = getTask({ taskId, runId, task: groups });
   const run = dagRuns.find((r) => r.runId === runId);
@@ -68,6 +71,11 @@ const TaskInstance = ({ taskId, runId }) => {
 
   const instance = group.instances.find((ti) => ti.runId === runId);
 
+  let taskActionsTitle = 'Task Actions';
+  if (isMapped) {
+    taskActionsTitle += ` for ${selectedRows.length || 'all'} mapped task${selectedRows.length !== 1 ? 's' : ''}`;
+  }
+
   return (
     <Box py="4px">
       {!isGroup && (
@@ -80,27 +88,40 @@ const TaskInstance = ({ taskId, runId }) => {
       )}
       {!isGroup && (
         <Box my={3}>
+          <Text as="strong">{taskActionsTitle}</Text>
+          <Flex maxHeight="20px" minHeight="20px">
+            {selectedRows.length ? (
+              <Text color="red.500">
+                Clear, Mark Failed, and Mark Success do not yet work with individual mapped tasks.
+              </Text>
+            ) : <Divider my={2} />}
+          </Flex>
+          {/* visibility={selectedRows.length ? 'visible' : 'hidden'} */}
           <VStack justifyContent="center" divider={<StackDivider my={3} />}>
             <RunAction
               runId={runId}
               taskId={taskId}
               dagId={dagId}
+              selectedRows={selectedRows}
             />
             <ClearAction
               runId={runId}
               taskId={taskId}
               dagId={dagId}
               executionDate={executionDate}
+              selectedRows={selectedRows}
             />
             <MarkFailedAction
               runId={runId}
               taskId={taskId}
               dagId={dagId}
+              selectedRows={selectedRows}
             />
             <MarkSuccessAction
               runId={runId}
               taskId={taskId}
               dagId={dagId}
+              selectedRows={selectedRows}
             />
           </VStack>
           <Divider my={2} />
@@ -122,7 +143,7 @@ const TaskInstance = ({ taskId, runId }) => {
         extraLinks={extraLinks}
       />
       {isMapped && (
-        <MappedInstances dagId={dagId} runId={runId} taskId={taskId} />
+        <MappedInstances dagId={dagId} runId={runId} taskId={taskId} selectRows={setSelectedRows} />
       )}
     </Box>
   );
diff --git a/airflow/www/static/js/tree/details/content/taskInstance/taskActions/Clear.jsx b/airflow/www/static/js/tree/details/content/taskInstance/taskActions/Clear.jsx
index 4196edc6b9..cada7b59ed 100644
--- a/airflow/www/static/js/tree/details/content/taskInstance/taskActions/Clear.jsx
+++ b/airflow/www/static/js/tree/details/content/taskInstance/taskActions/Clear.jsx
@@ -34,6 +34,7 @@ const Run = ({
   runId,
   taskId,
   executionDate,
+  selectedRows,
 }) => {
   const [affectedTasks, setAffectedTasks] = useState([]);
 
@@ -113,6 +114,7 @@ const Run = ({
         colorScheme="blue"
         onClick={onClick}
         isLoading={isLoading}
+        isDisabled={!!selectedRows.length}
         title="Clearing deletes the previous state of the task instance, allowing it to get re-triggered by the scheduler or a backfill command"
       >
         Clear
diff --git a/airflow/www/static/js/tree/details/content/taskInstance/taskActions/MarkFailed.jsx b/airflow/www/static/js/tree/details/content/taskInstance/taskActions/MarkFailed.jsx
index fe277c9eef..6bc10c066e 100644
--- a/airflow/www/static/js/tree/details/content/taskInstance/taskActions/MarkFailed.jsx
+++ b/airflow/www/static/js/tree/details/content/taskInstance/taskActions/MarkFailed.jsx
@@ -33,6 +33,7 @@ const MarkFailed = ({
   dagId,
   runId,
   taskId,
+  selectedRows,
 }) => {
   const [affectedTasks, setAffectedTasks] = useState([]);
 
@@ -99,7 +100,7 @@ const MarkFailed = ({
         <ActionButton bg={upstream && 'gray.100'} onClick={onToggleUpstream} name="Upstream" />
         <ActionButton bg={downstream && 'gray.100'} onClick={onToggleDownstream} name="Downstream" />
       </ButtonGroup>
-      <Button colorScheme="red" onClick={onClick} isLoading={isMarkLoading || isConfirmLoading}>
+      <Button colorScheme="red" onClick={onClick} isLoading={isMarkLoading || isConfirmLoading} isDisabled={!!selectedRows.length}>
         Mark Failed
       </Button>
       <ConfirmDialog
diff --git a/airflow/www/static/js/tree/details/content/taskInstance/taskActions/MarkSuccess.jsx b/airflow/www/static/js/tree/details/content/taskInstance/taskActions/MarkSuccess.jsx
index 06bc80c756..b4d2b8c047 100644
--- a/airflow/www/static/js/tree/details/content/taskInstance/taskActions/MarkSuccess.jsx
+++ b/airflow/www/static/js/tree/details/content/taskInstance/taskActions/MarkSuccess.jsx
@@ -30,7 +30,7 @@ import ActionButton from './ActionButton';
 import { useMarkSuccessTask, useConfirmMarkTask } from '../../../../api';
 
 const Run = ({
-  dagId, runId, taskId,
+  dagId, runId, taskId, selectedRows,
 }) => {
   const [affectedTasks, setAffectedTasks] = useState([]);
 
@@ -95,7 +95,7 @@ const Run = ({
         <ActionButton bg={upstream && 'gray.100'} onClick={onToggleUpstream} name="Upstream" />
         <ActionButton bg={downstream && 'gray.100'} onClick={onToggleDownstream} name="Downstream" />
       </ButtonGroup>
-      <Button colorScheme="green" onClick={onClick} isLoading={isMarkLoading || isConfirmLoading}>
+      <Button colorScheme="green" onClick={onClick} isLoading={isMarkLoading || isConfirmLoading} isDisabled={!!selectedRows.length}>
         Mark Success
       </Button>
       <ConfirmDialog
diff --git a/airflow/www/static/js/tree/details/content/taskInstance/taskActions/Run.jsx b/airflow/www/static/js/tree/details/content/taskInstance/taskActions/Run.jsx
index 204cec44c2..41c8bb9c6f 100644
--- a/airflow/www/static/js/tree/details/content/taskInstance/taskActions/Run.jsx
+++ b/airflow/www/static/js/tree/details/content/taskInstance/taskActions/Run.jsx
@@ -30,6 +30,7 @@ const Run = ({
   dagId,
   runId,
   taskId,
+  selectedRows,
 }) => {
   const [ignoreAllDeps, setIgnoreAllDeps] = useState(false);
   const onToggleAllDeps = () => setIgnoreAllDeps(!ignoreAllDeps);
@@ -43,11 +44,22 @@ const Run = ({
   const { mutate: onRun, isLoading } = useRunTask(dagId, runId, taskId);
 
   const onClick = () => {
-    onRun({
-      ignoreAllDeps,
-      ignoreTaskState,
-      ignoreTaskDeps,
-    });
+    if (selectedRows.length) {
+      selectedRows.forEach((mapIndex) => {
+        onRun({
+          ignoreAllDeps,
+          ignoreTaskState,
+          ignoreTaskDeps,
+          mapIndex,
+        });
+      });
+    } else {
+      onRun({
+        ignoreAllDeps,
+        ignoreTaskState,
+        ignoreTaskDeps,
+      });
+    }
   };
 
   return (


[airflow] 10/19: Refactor to straighten up types

Posted by bb...@apache.org.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 3b12a914db3962b40141498f9eb43af91ecd80b2
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Mon Apr 18 22:13:10 2022 +0800

    Refactor to straighten up types
---
 airflow/models/dag.py | 122 +++++++++++++++++++++-----------------------------
 1 file changed, 52 insertions(+), 70 deletions(-)

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 0694f37550..755505b5d0 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -39,6 +39,7 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Sequence,
     Set,
     Tuple,
     Type,
@@ -1340,40 +1341,29 @@ class DAG(LoggingMixin):
             start_date = (timezone.utcnow() - timedelta(30)).replace(
                 hour=0, minute=0, second=0, microsecond=0
             )
-
-        if state is None:
-            state = []
-
-        return (
-            cast(
-                Query,
-                self._get_task_instances(
-                    task_ids=None,
-                    start_date=start_date,
-                    end_date=end_date,
-                    run_id=None,
-                    state=state,
-                    include_subdags=False,
-                    include_parentdag=False,
-                    include_dependent_dags=False,
-                    exclude_task_ids=cast(List[str], []),
-                    session=session,
-                ),
-            )
-            .order_by(DagRun.execution_date)
-            .all()
+        query = self._get_task_instances(
+            task_ids=None,
+            start_date=start_date,
+            end_date=end_date,
+            run_id=None,
+            state=state or (),
+            include_subdags=False,
+            include_parentdag=False,
+            include_dependent_dags=False,
+            exclude_task_ids=(),
+            session=session,
         )
+        return cast(Query, query).order_by(DagRun.execution_date).all()
 
     @overload
     def _get_task_instances(
         self,
         *,
         task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
-        task_ids_and_map_indexes,
         start_date: Optional[datetime],
         end_date: Optional[datetime],
         run_id: Optional[str],
-        state: Union[TaskInstanceState, List[TaskInstanceState]],
+        state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
@@ -1392,7 +1382,7 @@ class DAG(LoggingMixin):
         start_date: Optional[datetime],
         end_date: Optional[datetime],
         run_id: Optional[str],
-        state: Union[TaskInstanceState, List[TaskInstanceState]],
+        state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
@@ -1413,7 +1403,7 @@ class DAG(LoggingMixin):
         start_date: Optional[datetime],
         end_date: Optional[datetime],
         run_id: Optional[str],
-        state: Union[TaskInstanceState, List[TaskInstanceState]],
+        state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
@@ -1441,18 +1431,6 @@ class DAG(LoggingMixin):
             tis = session.query(TaskInstance)
         tis = tis.join(TaskInstance.dag_run)
 
-        task_ids_and_map_indexes = None
-        if task_ids is not None:
-            task_ids_and_map_indexes = [item for item in task_ids if isinstance(item, tuple)]
-        if task_ids_and_map_indexes:
-            task_ids = None  # nullify since we have indexes
-
-        exclude_task_ids_and_map_indexes = None
-        if exclude_task_ids is not None:
-            exclude_task_ids_and_map_indexes = [item for item in exclude_task_ids if isinstance(item, tuple)]
-        if exclude_task_ids_and_map_indexes:
-            exclude_task_ids = None
-
         if include_subdags:
             # Crafting the right filter for dag_id and task_ids combo
             conditions = []
@@ -1467,12 +1445,13 @@ class DAG(LoggingMixin):
             tis = tis.filter(TaskInstance.run_id == run_id)
         if start_date:
             tis = tis.filter(DagRun.execution_date >= start_date)
-        if task_ids:
-            tis = tis.filter(TaskInstance.task_id.in_(task_ids))
-        if task_ids_and_map_indexes:
-            tis = tis.filter(
-                tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_ids_and_map_indexes)
-            )
+
+        if task_ids is None:
+            pass  # Disable filter if not set.
+        elif isinstance(next(iter(task_ids), None), str):
+            tis = tis.filter(TI.task_id.in_(task_ids))
+        else:
+            tis = tis.filter(tuple_(TI.task_id, TI.map_index).in_(task_ids))
 
         # This allows allow_trigger_in_future config to take affect, rather than mandating exec_date <= UTC
         if end_date or not self.allow_future_exec_dates:
@@ -1610,33 +1589,29 @@ class DAG(LoggingMixin):
             if as_pk_tuple:
                 result.update(TaskInstanceKey(*cols) for cols in tis.all())
             else:
-                result.update(ti.key for ti in tis.all())
+                result.update(ti.key for ti in tis)
 
             if exclude_task_ids is not None:
-                result = set(
-                    filter(
-                        lambda key: key.task_id not in exclude_task_ids,
-                        result,
-                    )
-                )
-
-            if exclude_task_ids_and_map_indexes is not None:
-                result = set(
-                    filter(
-                        lambda key: (key.task_id, key.map_index) not in exclude_task_ids_and_map_indexes,
-                        result,
-                    )
-                )
+                result = {
+                    task
+                    for task in result
+                    if task.task_id not in exclude_task_ids
+                    and (task.task_id, task.map_index) not in exclude_task_ids
+                }
 
         if as_pk_tuple:
             return result
-        elif result:
+        if result:
             # We've been asked for objects, lets combine it all back in to a result set
-            tis = tis.with_entities(TI.dag_id, TI.task_id, TI.run_id, TI.map_index)
-
-            tis = session.query(TI).filter(TI.filter_for_tis(result))
-        elif exclude_task_ids_and_map_indexes:
-            tis = tis.filter(tuple_(TI.task_id, TI.map_index).notin_(exclude_task_ids_and_map_indexes))
+            ti_filters = TI.filter_for_tis(result)
+            if ti_filters is not None:
+                tis = session.query(TI).filter(ti_filters)
+        elif exclude_task_ids is None:
+            pass  # Disable filter if not set.
+        elif isinstance(next(iter(exclude_task_ids), None), str):
+            tis = tis.filter(TI.task_id.notin_(exclude_task_ids))
+        else:
+            tis = tis.filter(tuple_(TI.task_id, TI.map_index).notin_(exclude_task_ids))
 
         return tis
 
@@ -1687,11 +1662,18 @@ class DAG(LoggingMixin):
 
         task = self.get_task(task_id)
         task.dag = self
-        task_map_indexes = [(task, map_index)] if map_index else [task]
-        task_id_map_indexes = {(task_id, map_index)} if map_index else {task_id}
+
+        tasks_to_set_state: Union[List[Operator], List[Tuple[Operator, int]]]
+        task_ids_to_exclude_from_clear: Union[Set[str], Set[Tuple[str, int]]]
+        if map_index is None:
+            tasks_to_set_state = [task]
+            task_ids_to_exclude_from_clear = {task_id}
+        else:
+            tasks_to_set_state = [(task, map_index)]
+            task_ids_to_exclude_from_clear = {(task_id, map_index)}
 
         altered = set_state(
-            tasks=task_map_indexes,
+            tasks=tasks_to_set_state,
             execution_date=execution_date,
             run_id=run_id,
             upstream=upstream,
@@ -1726,7 +1708,7 @@ class DAG(LoggingMixin):
             only_failed=True,
             session=session,
             # Exclude the task itself from being cleared
-            exclude_task_ids=task_id_map_indexes,
+            exclude_task_ids=task_ids_to_exclude_from_clear,
         )
 
         return altered
@@ -1784,7 +1766,7 @@ class DAG(LoggingMixin):
     @provide_session
     def clear(
         self,
-        task_ids: Union[Iterable[str], Iterable[Tuple[str, int]], None] = None,
+        task_ids: Union[Collection[str], Collection[Tuple[str, int]], None] = None,
         start_date: Optional[datetime] = None,
         end_date: Optional[datetime] = None,
         only_failed: bool = False,
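
Two ideas in this refactor can be sketched in isolation: the `map_index is None` test that replaces the earlier truthiness check, and the way the filter decides between plain task ids and `(task_id, map_index)` pairs by peeking at the first element. The helper names below (`exclusion_set`, `selector_kind`) are hypothetical and exist only for illustration; no database or Airflow import is involved.

```python
from typing import Collection, Optional, Set, Tuple, Union

TaskIds = Union[Collection[str], Collection[Tuple[str, int]]]


def exclusion_set(
    task_id: str, map_index: Optional[int]
) -> Union[Set[str], Set[Tuple[str, int]]]:
    """Hypothetical helper: exclude the whole task or one mapped instance."""
    # `is None` rather than truthiness: map_index 0 is a valid mapped instance.
    if map_index is None:
        return {task_id}
    return {(task_id, map_index)}


def selector_kind(task_ids: Optional[TaskIds]) -> str:
    """Classify a selector by inspecting its first element, as the diff does."""
    if task_ids is None:
        return "no filter"
    if isinstance(next(iter(task_ids), None), str):
        return "task_id"
    # Tuples land here, and so does an empty collection.
    return "task_id+map_index"
```

With a truthiness test, `exclusion_set("t", 0)` would have returned `{"t"}` and excluded every mapped instance of the task rather than only index 0.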


[airflow] 12/19: Introduce tuple_().in_() shim for MSSQL compat

Posted by bb...@apache.org.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 469092494da6b8baa6cfe145b76e40eaa495635e
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Tue Apr 19 18:01:55 2022 +0800

    Introduce tuple_().in_() shim for MSSQL compat
---
 airflow/api/common/mark_tasks.py | 5 +++--
 airflow/models/dag.py            | 8 ++++----
 airflow/utils/sqlalchemy.py      | 8 +++-----
 3 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index 349b935e82..1d4709fb82 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -20,7 +20,7 @@
 from datetime import datetime
 from typing import TYPE_CHECKING, Collection, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union
 
-from sqlalchemy import or_, tuple_
+from sqlalchemy import or_
 from sqlalchemy.orm import contains_eager
 from sqlalchemy.orm.session import Session as SASession
 
@@ -32,6 +32,7 @@ from airflow.operators.subdag import SubDagOperator
 from airflow.utils import timezone
 from airflow.utils.helpers import exactly_one
 from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import tuple_in_condition
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunType
 
@@ -203,7 +204,7 @@ def get_all_dag_task_query(
     if is_string_list:
         qry_dag = qry_dag.filter(TaskInstance.task_id.in_(task_ids))
     else:
-        qry_dag = qry_dag.filter(tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_ids))
+        qry_dag = qry_dag.filter(tuple_in_condition((TaskInstance.task_id, TaskInstance.map_index), task_ids))
     qry_dag = qry_dag.filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state)).options(
         contains_eager(TaskInstance.dag_run)
     )
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 9c93bcef13..83860ba591 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -52,7 +52,7 @@ import jinja2
 import pendulum
 from dateutil.relativedelta import relativedelta
 from pendulum.tz.timezone import Timezone
-from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, or_, tuple_
+from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, not_, or_
 from sqlalchemy.orm import backref, joinedload, relationship
 from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.session import Session
@@ -85,7 +85,7 @@ from airflow.utils.file import correct_maybe_zipped
 from airflow.utils.helpers import exactly_one, validate_key
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, with_row_locks
+from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, tuple_in_condition, with_row_locks
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType
 
@@ -1451,7 +1451,7 @@ class DAG(LoggingMixin):
         elif isinstance(next(iter(task_ids), None), str):
             tis = tis.filter(TI.task_id.in_(task_ids))
         else:
-            tis = tis.filter(tuple_(TI.task_id, TI.map_index).in_(task_ids))
+            tis = tis.filter(tuple_in_condition((TI.task_id, TI.map_index), task_ids))
 
         # This allows allow_trigger_in_future config to take affect, rather than mandating exec_date <= UTC
         if end_date or not self.allow_future_exec_dates:
@@ -1611,7 +1611,7 @@ class DAG(LoggingMixin):
         elif isinstance(next(iter(exclude_task_ids), None), str):
             tis = tis.filter(TI.task_id.notin_(exclude_task_ids))
         else:
-            tis = tis.filter(tuple_(TI.task_id, TI.map_index).notin_(exclude_task_ids))
+            tis = tis.filter(not_(tuple_in_condition((TI.task_id, TI.map_index), exclude_task_ids)))
 
         return tis
 
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index de4ad01e69..5c36d826b2 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -19,11 +19,12 @@
 import datetime
 import json
 import logging
+from operator import and_, or_
 from typing import Any, Dict, Iterable, Tuple
 
 import pendulum
 from dateutil import relativedelta
-from sqlalchemy import and_, event, false, nullsfirst, or_, tuple_
+from sqlalchemy import event, nullsfirst, tuple_
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql import ColumnElement
@@ -338,7 +339,4 @@ def tuple_in_condition(
     """
     if settings.engine.dialect.name != "mssql":
         return tuple_(*columns).in_(collection)
-    clauses = [and_(*(c == v for c, v in zip(columns, values))) for values in collection]
-    if not clauses:
-        return false()
-    return or_(*clauses)
+    return or_(*(and_(*(c == v for c, v in zip(columns, values))) for values in collection))
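
The shim keeps the tuple `IN` form on every dialect except MSSQL, where it expands the condition into an OR of per-row ANDs. The equivalence of the two forms can be checked without a database. The sketch below is plain Python over in-memory rows, not SQLAlchemy, and the function names are hypothetical.

```python
from typing import Iterable, Sequence


def tuple_in(row: Sequence, collection: Iterable[Sequence]) -> bool:
    """Reference semantics: (a, b) IN ((a1, b1), (a2, b2), ...)."""
    return tuple(row) in {tuple(values) for values in collection}


def or_of_ands(row: Sequence, collection: Iterable[Sequence]) -> bool:
    """Expanded form: (a = a1 AND b = b1) OR (a = a2 AND b = b2) OR ..."""
    return any(
        all(column == value for column, value in zip(row, values))
        for values in collection
    )


rows = [("extract", 0), ("extract", 1), ("load", 0)]
wanted = [("extract", 1), ("load", 0)]

matched_in = [r for r in rows if tuple_in(r, wanted)]
matched_expanded = [r for r in rows if or_of_ands(r, wanted)]
```

Both lists contain the same rows. For an empty collection, `any` over nothing is `False`, which matches the explicit `false()` returned by the removed lines.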


[airflow] 18/19: Chain map_index params

Posted by bb...@apache.org.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit f489b107ec82b4d07ac6928e102bbceb9750629f
Author: Brent Bovenzi <br...@gmail.com>
AuthorDate: Tue Apr 19 09:57:37 2022 -0400

    Chain map_index params
---
 airflow/www/static/js/tree/api/useClearTask.js       | 9 ++++++---
 airflow/www/static/js/tree/api/useConfirmMarkTask.js | 6 +++++-
 airflow/www/static/js/tree/api/useMarkFailedTask.js  | 9 ++++++---
 airflow/www/static/js/tree/api/useMarkSuccessTask.js | 8 ++++++--
 4 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/airflow/www/static/js/tree/api/useClearTask.js b/airflow/www/static/js/tree/api/useClearTask.js
index eea4b2b656..2ea3eee486 100644
--- a/airflow/www/static/js/tree/api/useClearTask.js
+++ b/airflow/www/static/js/tree/api/useClearTask.js
@@ -51,10 +51,13 @@ export default function useClearTask({
         downstream,
         recursive,
         only_failed: failed,
-        map_indexes: mapIndexes,
-      }).toString();
+      });
+
+      mapIndexes.forEach((mi) => {
+        params.append('map_index', mi);
+      });
 
-      return axios.post(clearUrl, params, {
+      return axios.post(clearUrl, params.toString(), {
         headers: {
           'Content-Type': 'application/x-www-form-urlencoded',
         },
diff --git a/airflow/www/static/js/tree/api/useConfirmMarkTask.js b/airflow/www/static/js/tree/api/useConfirmMarkTask.js
index 85b5f7df42..d1f8eef9d3 100644
--- a/airflow/www/static/js/tree/api/useConfirmMarkTask.js
+++ b/airflow/www/static/js/tree/api/useConfirmMarkTask.js
@@ -40,8 +40,12 @@ export default function useConfirmMarkTask({
         upstream,
         downstream,
         state,
-        map_indexes: mapIndexes,
       });
+
+      mapIndexes.forEach((mi) => {
+        params.append('map_index', mi);
+      });
+
       return axios.get(confirmUrl, { params });
     },
   );
diff --git a/airflow/www/static/js/tree/api/useMarkFailedTask.js b/airflow/www/static/js/tree/api/useMarkFailedTask.js
index 333fed21c2..a94ab22d0c 100644
--- a/airflow/www/static/js/tree/api/useMarkFailedTask.js
+++ b/airflow/www/static/js/tree/api/useMarkFailedTask.js
@@ -45,10 +45,13 @@ export default function useMarkFailedTask({
         future,
         upstream,
         downstream,
-        map_indexes: mapIndexes,
-      }).toString();
+      });
+
+      mapIndexes.forEach((mi) => {
+        params.append('map_index', mi);
+      });
 
-      return axios.post(failedUrl, params, {
+      return axios.post(failedUrl, params.toString(), {
         headers: {
           'Content-Type': 'application/x-www-form-urlencoded',
         },
diff --git a/airflow/www/static/js/tree/api/useMarkSuccessTask.js b/airflow/www/static/js/tree/api/useMarkSuccessTask.js
index cde919a274..47fda2f0f8 100644
--- a/airflow/www/static/js/tree/api/useMarkSuccessTask.js
+++ b/airflow/www/static/js/tree/api/useMarkSuccessTask.js
@@ -46,9 +46,13 @@ export default function useMarkSuccessTask({
         upstream,
         downstream,
         map_indexes: mapIndexes,
-      }).toString();
+      });
+
+      mapIndexes.forEach((mi) => {
+        params.append('map_index', mi);
+      });
 
-      return axios.post(successUrl, params, {
+      return axios.post(successUrl, params.toString(), {
         headers: {
           'Content-Type': 'application/x-www-form-urlencoded',
         },
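
The hunks above replace a single `map_indexes` entry with repeated `map_index` entries. A minimal sketch of the difference, using only the standard `URLSearchParams` API. The helper name `buildParams` and the sample values are hypothetical and not part of the commit:

```javascript
// Hypothetical helper, not from the commit: builds form-encoded params
// with one `map_index` entry per selected mapped task instance.
function buildParams(base, mapIndexes = []) {
  const params = new URLSearchParams(base);
  mapIndexes.forEach((mi) => {
    // append() adds another entry instead of overwriting the key.
    params.append('map_index', mi);
  });
  return params;
}

// Passing the array in the constructor coerces it to one comma-joined string.
const joined = new URLSearchParams({ map_indexes: [1, 2] }).toString();
// joined === 'map_indexes=1%2C2'

// Appending per index yields one key/value pair per selected instance.
const repeated = buildParams({ task_id: 'example_task' }, [1, 2]).toString();
// repeated === 'task_id=example_task&map_index=1&map_index=2'
```

With repeated keys, each index arrives as its own form value, assuming the receiving view reads `map_index` as a multi-valued parameter. Note that in the `useMarkSuccessTask.js` hunk above, `map_indexes: mapIndexes` is a context line, so that file still sends the comma-joined entry alongside the repeated ones.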


[airflow] 16/19: Readd mapped instance table selection

Posted by bb...@apache.org.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit acb534127ba293e97c1073e9b3581715c8a8f03f
Author: Brent Bovenzi <br...@gmail.com>
AuthorDate: Tue Apr 12 14:51:45 2022 -0400

    Readd mapped instance table selection
---
 airflow/www/static/js/tree/Table.jsx               |  4 +-
 airflow/www/static/js/tree/api/useClearTask.js     |  4 +-
 .../www/static/js/tree/api/useConfirmMarkTask.js   | 12 +++---
 .../www/static/js/tree/api/useMarkFailedTask.js    |  4 +-
 .../www/static/js/tree/api/useMarkSuccessTask.js   |  4 +-
 airflow/www/static/js/tree/api/useRunTask.js       | 43 ++++++++++++----------
 .../js/tree/details/content/taskInstance/index.jsx | 25 ++++++-------
 .../content/taskInstance/taskActions/Clear.jsx     |  5 ++-
 .../taskInstance/taskActions/MarkFailed.jsx        |  6 ++-
 .../taskInstance/taskActions/MarkSuccess.jsx       | 10 +++--
 .../content/taskInstance/taskActions/Run.jsx       | 24 ++++--------
 11 files changed, 72 insertions(+), 69 deletions(-)

diff --git a/airflow/www/static/js/tree/Table.jsx b/airflow/www/static/js/tree/Table.jsx
index 152f647ea3..e500b08966 100644
--- a/airflow/www/static/js/tree/Table.jsx
+++ b/airflow/www/static/js/tree/Table.jsx
@@ -46,7 +46,7 @@ import {
 } from 'react-icons/ti';
 
 const IndeterminateCheckbox = forwardRef(
-  ({ indeterminate, ...rest }, ref) => {
+  ({ indeterminate, checked, ...rest }, ref) => {
     const defaultRef = useRef();
     const resolvedRef = ref || defaultRef;
 
@@ -55,7 +55,7 @@ const IndeterminateCheckbox = forwardRef(
     }, [resolvedRef, indeterminate]);
 
     return (
-      <Checkbox ref={resolvedRef} {...rest} />
+      <Checkbox ref={resolvedRef} isChecked={checked} {...rest} />
     );
   },
 );
diff --git a/airflow/www/static/js/tree/api/useClearTask.js b/airflow/www/static/js/tree/api/useClearTask.js
index bcf99bb250..eea4b2b656 100644
--- a/airflow/www/static/js/tree/api/useClearTask.js
+++ b/airflow/www/static/js/tree/api/useClearTask.js
@@ -36,7 +36,7 @@ export default function useClearTask({
   return useMutation(
     ['clearTask', dagId, runId, taskId],
     ({
-      past, future, upstream, downstream, recursive, failed, confirmed,
+      past, future, upstream, downstream, recursive, failed, confirmed, mapIndexes = [],
     }) => {
       const params = new URLSearchParams({
         csrf_token: csrfToken,
@@ -51,6 +51,7 @@ export default function useClearTask({
         downstream,
         recursive,
         only_failed: failed,
+        map_indexes: mapIndexes,
       }).toString();
 
       return axios.post(clearUrl, params, {
@@ -71,6 +72,7 @@ export default function useClearTask({
         }
         if (!status || status !== 'error') {
           queryClient.invalidateQueries('treeData');
+          queryClient.invalidateQueries('mappedInstances', dagId, runId, taskId);
           startRefresh();
         }
       },
diff --git a/airflow/www/static/js/tree/api/useConfirmMarkTask.js b/airflow/www/static/js/tree/api/useConfirmMarkTask.js
index 1450a15d3d..85b5f7df42 100644
--- a/airflow/www/static/js/tree/api/useConfirmMarkTask.js
+++ b/airflow/www/static/js/tree/api/useConfirmMarkTask.js
@@ -29,9 +29,9 @@ export default function useConfirmMarkTask({
   return useMutation(
     ['confirmStateChange', dagId, runId, taskId, state],
     ({
-      past, future, upstream, downstream,
-    }) => axios.get(confirmUrl, {
-      params: {
+      past, future, upstream, downstream, mapIndexes = [],
+    }) => {
+      const params = new URLSearchParams({
         dag_id: dagId,
         dag_run_id: runId,
         task_id: taskId,
@@ -40,7 +40,9 @@ export default function useConfirmMarkTask({
         upstream,
         downstream,
         state,
-      },
-    }),
+        map_indexes: mapIndexes,
+      });
+      return axios.get(confirmUrl, { params });
+    },
   );
 }
diff --git a/airflow/www/static/js/tree/api/useMarkFailedTask.js b/airflow/www/static/js/tree/api/useMarkFailedTask.js
index f2fd28bdb2..333fed21c2 100644
--- a/airflow/www/static/js/tree/api/useMarkFailedTask.js
+++ b/airflow/www/static/js/tree/api/useMarkFailedTask.js
@@ -33,7 +33,7 @@ export default function useMarkFailedTask({
   return useMutation(
     ['markFailed', dagId, runId, taskId],
     ({
-      past, future, upstream, downstream,
+      past, future, upstream, downstream, mapIndexes = [],
     }) => {
       const params = new URLSearchParams({
         csrf_token: csrfToken,
@@ -45,6 +45,7 @@ export default function useMarkFailedTask({
         future,
         upstream,
         downstream,
+        map_indexes: mapIndexes,
       }).toString();
 
       return axios.post(failedUrl, params, {
@@ -56,6 +57,7 @@ export default function useMarkFailedTask({
     {
       onSuccess: () => {
         queryClient.invalidateQueries('treeData');
+        queryClient.invalidateQueries('mappedInstances', dagId, runId, taskId);
         startRefresh();
       },
     },
diff --git a/airflow/www/static/js/tree/api/useMarkSuccessTask.js b/airflow/www/static/js/tree/api/useMarkSuccessTask.js
index 92ba539de6..cde919a274 100644
--- a/airflow/www/static/js/tree/api/useMarkSuccessTask.js
+++ b/airflow/www/static/js/tree/api/useMarkSuccessTask.js
@@ -33,7 +33,7 @@ export default function useMarkSuccessTask({
   return useMutation(
     ['markSuccess', dagId, runId, taskId],
     ({
-      past, future, upstream, downstream,
+      past, future, upstream, downstream, mapIndexes = [],
     }) => {
       const params = new URLSearchParams({
         csrf_token: csrfToken,
@@ -45,6 +45,7 @@ export default function useMarkSuccessTask({
         future,
         upstream,
         downstream,
+        map_indexes: mapIndexes,
       }).toString();
 
       return axios.post(successUrl, params, {
@@ -56,6 +57,7 @@ export default function useMarkSuccessTask({
     {
       onSuccess: () => {
         queryClient.invalidateQueries('treeData');
+        queryClient.invalidateQueries('mappedInstances', dagId, runId, taskId);
         startRefresh();
       },
     },
diff --git a/airflow/www/static/js/tree/api/useRunTask.js b/airflow/www/static/js/tree/api/useRunTask.js
index 44a9e14bf4..9e45c42f59 100644
--- a/airflow/www/static/js/tree/api/useRunTask.js
+++ b/airflow/www/static/js/tree/api/useRunTask.js
@@ -32,32 +32,34 @@ export default function useRunTask(dagId, runId, taskId) {
   const { startRefresh } = useAutoRefresh();
   return useMutation(
     ['runTask', dagId, runId, taskId],
-    ({
+    async ({
       ignoreAllDeps,
       ignoreTaskState,
       ignoreTaskDeps,
-      mapIndex = -1,
-    }) => {
-      const params = new URLSearchParams({
-        csrf_token: csrfToken,
-        dag_id: dagId,
-        dag_run_id: runId,
-        task_id: taskId,
-        ignore_all_deps: ignoreAllDeps,
-        ignore_task_deps: ignoreTaskDeps,
-        ignore_ti_state: ignoreTaskState,
-        map_index: mapIndex,
-      }).toString();
+      mapIndexes,
+    }) => Promise.all(
+      (mapIndexes.length ? mapIndexes : [-1]).map((mi) => {
+        const params = new URLSearchParams({
+          csrf_token: csrfToken,
+          dag_id: dagId,
+          dag_run_id: runId,
+          task_id: taskId,
+          ignore_all_deps: ignoreAllDeps,
+          ignore_task_deps: ignoreTaskDeps,
+          ignore_ti_state: ignoreTaskState,
+          map_index: mi,
+        }).toString();
 
-      return axios.post(runUrl, params, {
-        headers: {
-          'Content-Type': 'application/x-www-form-urlencoded',
-        },
-      });
-    },
+        return axios.post(runUrl, params, {
+          headers: {
+            'Content-Type': 'application/x-www-form-urlencoded',
+          },
+        });
+      }),
+    ),
     {
       onSuccess: (data) => {
-        const { message, status } = data;
+        const { message, status } = data.length ? data[0] : data;
         if (message && status === 'error') {
           toast({
             description: message,
@@ -67,6 +69,7 @@ export default function useRunTask(dagId, runId, taskId) {
         }
         if (!status || status !== 'error') {
           queryClient.invalidateQueries('treeData');
+          queryClient.invalidateQueries('mappedInstances', dagId, runId, taskId);
           startRefresh();
         }
       },
diff --git a/airflow/www/static/js/tree/details/content/taskInstance/index.jsx b/airflow/www/static/js/tree/details/content/taskInstance/index.jsx
index 62ffee156a..0e4f441e24 100644
--- a/airflow/www/static/js/tree/details/content/taskInstance/index.jsx
+++ b/airflow/www/static/js/tree/details/content/taskInstance/index.jsx
@@ -24,7 +24,6 @@ import {
   Divider,
   StackDivider,
   Text,
-  Flex,
 } from '@chakra-ui/react';
 
 import RunAction from './taskActions/Run';
@@ -89,39 +88,32 @@ const TaskInstance = ({ taskId, runId }) => {
       {!isGroup && (
         <Box my={3}>
           <Text as="strong">{taskActionsTitle}</Text>
-          <Flex maxHeight="20px" minHeight="20px">
-            {selectedRows.length ? (
-              <Text color="red.500">
-                Clear, Mark Failed, and Mark Success do not yet work with individual mapped tasks.
-              </Text>
-            ) : <Divider my={2} />}
-          </Flex>
-          {/* visibility={selectedRows.length ? 'visible' : 'hidden'} */}
+          <Divider my={2} />
           <VStack justifyContent="center" divider={<StackDivider my={3} />}>
             <RunAction
               runId={runId}
               taskId={taskId}
               dagId={dagId}
-              selectedRows={selectedRows}
+              mapIndexes={selectedRows}
             />
             <ClearAction
               runId={runId}
               taskId={taskId}
               dagId={dagId}
               executionDate={executionDate}
-              selectedRows={selectedRows}
+              mapIndexes={selectedRows}
             />
             <MarkFailedAction
               runId={runId}
               taskId={taskId}
               dagId={dagId}
-              selectedRows={selectedRows}
+              mapIndexes={selectedRows}
             />
             <MarkSuccessAction
               runId={runId}
               taskId={taskId}
               dagId={dagId}
-              selectedRows={selectedRows}
+              mapIndexes={selectedRows}
             />
           </VStack>
           <Divider my={2} />
@@ -143,7 +135,12 @@ const TaskInstance = ({ taskId, runId }) => {
         extraLinks={extraLinks}
       />
       {isMapped && (
-        <MappedInstances dagId={dagId} runId={runId} taskId={taskId} selectRows={setSelectedRows} />
+        <MappedInstances
+          dagId={dagId}
+          runId={runId}
+          taskId={taskId}
+          selectRows={setSelectedRows}
+        />
       )}
     </Box>
   );
diff --git a/airflow/www/static/js/tree/details/content/taskInstance/taskActions/Clear.jsx b/airflow/www/static/js/tree/details/content/taskInstance/taskActions/Clear.jsx
index cada7b59ed..d825976ed2 100644
--- a/airflow/www/static/js/tree/details/content/taskInstance/taskActions/Clear.jsx
+++ b/airflow/www/static/js/tree/details/content/taskInstance/taskActions/Clear.jsx
@@ -34,7 +34,7 @@ const Run = ({
   runId,
   taskId,
   executionDate,
-  selectedRows,
+  mapIndexes,
 }) => {
   const [affectedTasks, setAffectedTasks] = useState([]);
 
@@ -74,6 +74,7 @@ const Run = ({
         recursive,
         failed,
         confirmed: false,
+        mapIndexes,
       });
       setAffectedTasks(data);
       onOpen();
@@ -92,6 +93,7 @@ const Run = ({
         recursive,
         failed,
         confirmed: true,
+        mapIndexes,
       });
       setAffectedTasks([]);
       onClose();
@@ -114,7 +116,6 @@ const Run = ({
         colorScheme="blue"
         onClick={onClick}
         isLoading={isLoading}
-        isDisabled={!!selectedRows.length}
         title="Clearing deletes the previous state of the task instance, allowing it to get re-triggered by the scheduler or a backfill command"
       >
         Clear
diff --git a/airflow/www/static/js/tree/details/content/taskInstance/taskActions/MarkFailed.jsx b/airflow/www/static/js/tree/details/content/taskInstance/taskActions/MarkFailed.jsx
index 6bc10c066e..12f8bcfeef 100644
--- a/airflow/www/static/js/tree/details/content/taskInstance/taskActions/MarkFailed.jsx
+++ b/airflow/www/static/js/tree/details/content/taskInstance/taskActions/MarkFailed.jsx
@@ -33,7 +33,7 @@ const MarkFailed = ({
   dagId,
   runId,
   taskId,
-  selectedRows,
+  mapIndexes,
 }) => {
   const [affectedTasks, setAffectedTasks] = useState([]);
 
@@ -69,6 +69,7 @@ const MarkFailed = ({
         future,
         upstream,
         downstream,
+        mapIndexes,
       });
       setAffectedTasks(data);
       onOpen();
@@ -84,6 +85,7 @@ const MarkFailed = ({
         future,
         upstream,
         downstream,
+        mapIndexes,
       });
       setAffectedTasks([]);
       onClose();
@@ -100,7 +102,7 @@ const MarkFailed = ({
         <ActionButton bg={upstream && 'gray.100'} onClick={onToggleUpstream} name="Upstream" />
         <ActionButton bg={downstream && 'gray.100'} onClick={onToggleDownstream} name="Downstream" />
       </ButtonGroup>
-      <Button colorScheme="red" onClick={onClick} isLoading={isMarkLoading || isConfirmLoading} isDisabled={!!selectedRows.length}>
+      <Button colorScheme="red" onClick={onClick} isLoading={isMarkLoading || isConfirmLoading}>
         Mark Failed
       </Button>
       <ConfirmDialog
diff --git a/airflow/www/static/js/tree/details/content/taskInstance/taskActions/MarkSuccess.jsx b/airflow/www/static/js/tree/details/content/taskInstance/taskActions/MarkSuccess.jsx
index b4d2b8c047..bdf59e6a4d 100644
--- a/airflow/www/static/js/tree/details/content/taskInstance/taskActions/MarkSuccess.jsx
+++ b/airflow/www/static/js/tree/details/content/taskInstance/taskActions/MarkSuccess.jsx
@@ -29,8 +29,8 @@ import ConfirmDialog from '../../ConfirmDialog';
 import ActionButton from './ActionButton';
 import { useMarkSuccessTask, useConfirmMarkTask } from '../../../../api';
 
-const Run = ({
-  dagId, runId, taskId, selectedRows,
+const MarkSuccess = ({
+  dagId, runId, taskId, mapIndexes,
 }) => {
   const [affectedTasks, setAffectedTasks] = useState([]);
 
@@ -64,6 +64,7 @@ const Run = ({
         future,
         upstream,
         downstream,
+        mapIndexes,
       });
       setAffectedTasks(data);
       onOpen();
@@ -79,6 +80,7 @@ const Run = ({
         future,
         upstream,
         downstream,
+        mapIndexes,
       });
       setAffectedTasks([]);
       onClose();
@@ -95,7 +97,7 @@ const Run = ({
         <ActionButton bg={upstream && 'gray.100'} onClick={onToggleUpstream} name="Upstream" />
         <ActionButton bg={downstream && 'gray.100'} onClick={onToggleDownstream} name="Downstream" />
       </ButtonGroup>
-      <Button colorScheme="green" onClick={onClick} isLoading={isMarkLoading || isConfirmLoading} isDisabled={!!selectedRows.length}>
+      <Button colorScheme="green" onClick={onClick} isLoading={isMarkLoading || isConfirmLoading}>
         Mark Success
       </Button>
       <ConfirmDialog
@@ -109,4 +111,4 @@ const Run = ({
   );
 };
 
-export default Run;
+export default MarkSuccess;
diff --git a/airflow/www/static/js/tree/details/content/taskInstance/taskActions/Run.jsx b/airflow/www/static/js/tree/details/content/taskInstance/taskActions/Run.jsx
index 41c8bb9c6f..85d502aeed 100644
--- a/airflow/www/static/js/tree/details/content/taskInstance/taskActions/Run.jsx
+++ b/airflow/www/static/js/tree/details/content/taskInstance/taskActions/Run.jsx
@@ -30,7 +30,7 @@ const Run = ({
   dagId,
   runId,
   taskId,
-  selectedRows,
+  mapIndexes,
 }) => {
   const [ignoreAllDeps, setIgnoreAllDeps] = useState(false);
   const onToggleAllDeps = () => setIgnoreAllDeps(!ignoreAllDeps);
@@ -44,22 +44,12 @@ const Run = ({
   const { mutate: onRun, isLoading } = useRunTask(dagId, runId, taskId);
 
   const onClick = () => {
-    if (selectedRows.length) {
-      selectedRows.forEach((mapIndex) => {
-        onRun({
-          ignoreAllDeps,
-          ignoreTaskState,
-          ignoreTaskDeps,
-          mapIndex,
-        });
-      });
-    } else {
-      onRun({
-        ignoreAllDeps,
-        ignoreTaskState,
-        ignoreTaskDeps,
-      });
-    }
+    onRun({
+      ignoreAllDeps,
+      ignoreTaskState,
+      ignoreTaskDeps,
+      mapIndexes,
+    });
   };
 
   return (
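
The `useRunTask.js` change above issues one request per selected map index and falls back to a single `-1` index when nothing is selected. A minimal sketch of that fan-out with a stubbed request function. `runAll` and `fakePost` are hypothetical names standing in for the mutation function and `axios.post`, and the sketch defaults `mapIndexes` to `[]`, which the diff does not do:

```javascript
// Hypothetical stand-in for axios.post: resolves with the params it was given.
const fakePost = (params) => Promise.resolve({ sent: params.toString() });

// One request per map index; an empty selection falls back to [-1],
// the value the earlier code used as the default mapIndex.
function runAll(mapIndexes = [], post = fakePost) {
  const targets = mapIndexes.length ? mapIndexes : [-1];
  return Promise.all(
    targets.map((mi) => post(
      new URLSearchParams({ task_id: 'example_task', map_index: mi }),
    )),
  );
}

// Usage: three selected rows produce three requests, resolved in order.
runAll([0, 1, 2]).then((results) => {
  console.log(results.map((r) => r.sent));
});
```

`Promise.all` resolves to an array in the same order as its input and rejects as soon as any one request rejects, which is why the `onSuccess` handler in the diff has to unpack `data[0]` rather than read `message` and `status` directly.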


[airflow] 06/19: fixup! fixup! fixup! fixup! fixup! Allow marking/clearing mapped taskinstances from the UI

Posted by bb...@apache.org.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 7ea4ace57dcd80920f63a9cd439eed53da5d67da
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Wed Apr 13 21:03:34 2022 +0100

    fixup! fixup! fixup! fixup! fixup! Allow marking/clearing mapped taskinstances from the UI
---
 airflow/api/common/mark_tasks.py |  3 +--
 airflow/models/dag.py            | 21 ++++++++++++---------
 2 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index 84fd48f4e4..8885ebb59e 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -271,8 +271,7 @@ def _iter_existing_dag_run_infos(dag: DAG, run_ids: List[str]) -> Iterator[_DagR
         yield _DagRunInfo(dag_run.logical_date, dag.get_run_data_interval(dag_run))
 
 
-@provide_session
-def find_task_relatives(tasks, downstream, upstream, session: SASession = NEW_SESSION):
+def find_task_relatives(tasks, downstream, upstream):
     """Yield task ids and optionally ancestor and descendant ids."""
     for item in tasks:
         if isinstance(item, tuple):
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 931fd469d7..8856a841d3 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -1370,8 +1370,8 @@ class DAG(LoggingMixin):
     def _get_task_instances(
         self,
         *,
-        task_ids: Iterable[str],
-        task_ids_and_map_indexes: Optional[Iterable[Tuple[str, int]]],
+        task_ids,
+        task_ids_and_map_indexes,
         start_date: Optional[datetime],
         end_date: Optional[datetime],
         run_id: Optional[str],
@@ -1380,7 +1380,7 @@ class DAG(LoggingMixin):
         include_parentdag: bool,
         include_dependent_dags: bool,
         exclude_task_ids: Collection[str],
-        exclude_task_ids_and_map_indexes: Collection[Tuple[str, int]],
+        exclude_task_ids_and_map_indexes,
         session: Session,
         dag_bag: Optional["DagBag"] = ...,
     ) -> Iterable[TaskInstance]:
@@ -1390,8 +1390,8 @@ class DAG(LoggingMixin):
     def _get_task_instances(
         self,
         *,
-        task_ids: Iterable[str],
-        task_ids_and_map_indexes: Optional[Iterable[Tuple[str, int]]],
+        task_ids,
+        task_ids_and_map_indexes,
         as_pk_tuple: Literal[True],
         start_date: Optional[datetime],
         end_date: Optional[datetime],
@@ -1401,7 +1401,7 @@ class DAG(LoggingMixin):
         include_parentdag: bool,
         include_dependent_dags: bool,
         exclude_task_ids: Collection[str],
-        exclude_task_ids_and_map_indexes: Collection[Tuple[str, int]],
+        exclude_task_ids_and_map_indexes,
         session: Session,
         dag_bag: Optional["DagBag"] = ...,
         recursion_depth: int = ...,
@@ -1413,8 +1413,8 @@ class DAG(LoggingMixin):
     def _get_task_instances(
         self,
         *,
-        task_ids: Iterable[str],
-        task_ids_and_map_indexes: Optional[Iterable[Tuple[str, int]]],
+        task_ids,
+        task_ids_and_map_indexes,
         as_pk_tuple: Literal[True, None] = None,
         start_date: Optional[datetime],
         end_date: Optional[datetime],
@@ -1424,7 +1424,7 @@ class DAG(LoggingMixin):
         include_parentdag: bool,
         include_dependent_dags: bool,
         exclude_task_ids: Collection[str],
-        exclude_task_ids_and_map_indexes: Collection[Tuple[str, int]],
+        exclude_task_ids_and_map_indexes,
         session: Session,
         dag_bag: Optional["DagBag"] = None,
         recursion_depth: int = 0,
@@ -1587,6 +1587,7 @@ class DAG(LoggingMixin):
                     )
                     result.update(
                         downstream._get_task_instances(
+                            task_ids=None,
                             task_ids_and_map_indexes=None,
                             run_id=tii.run_id,
                             start_date=None,
@@ -1596,6 +1597,7 @@ class DAG(LoggingMixin):
                             include_dependent_dags=include_dependent_dags,
                             include_parentdag=False,
                             as_pk_tuple=True,
+                            exclude_task_ids=exclude_task_ids,
                             exclude_task_ids_and_map_indexes=exclude_task_ids_and_map_indexes,
                             dag_bag=dag_bag,
                             session=session,
@@ -1800,6 +1802,7 @@ class DAG(LoggingMixin):
         Clears a set of task instances associated with the current dag for
         a specified date range.
 
+        :param task_ids: List of task ids to clear
         :param task_ids_and_map_indexes: List of tuple of task_id, map_index to clear
         :param start_date: The minimum execution_date to clear
         :param end_date: The maximum execution_date to clear


[airflow] 17/19: Fix gantt/graph modal

Posted by bb...@apache.org.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 6b7b8026b1527164ffce10aca3a96e0ef8ba408d
Author: Brent Bovenzi <br...@gmail.com>
AuthorDate: Tue Apr 12 16:01:31 2022 -0400

    Fix gantt/graph modal
---
 airflow/www/static/js/dag.js           | 15 +++++++--------
 airflow/www/templates/airflow/dag.html |  3 +++
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/airflow/www/static/js/dag.js b/airflow/www/static/js/dag.js
index ca8b7cc676..ded26baeab 100644
--- a/airflow/www/static/js/dag.js
+++ b/airflow/www/static/js/dag.js
@@ -171,14 +171,11 @@ export function callModal({
   $('#extra_links').prev('hr').hide();
   $('#extra_links').empty().hide();
   if (mi >= 0) {
-    // Marking state and clear are not yet supported for mapped instances
-    $('#success_action').hide();
-    $('#failed_action').hide();
-    $('#clear_action').hide();
+    $('#modal_map_index').show();
+    $('#modal_map_index .value').text(mi);
   } else {
-    $('#success_action').show();
-    $('#failed_action').show();
-    $('#clear_action').show();
+    $('#modal_map_index').hide();
+    $('#modal_map_index .value').text('');
   }
   if (isSubDag) {
     $('#div_btn_subdag').show();
@@ -339,7 +336,6 @@ $(document).on('click', '.map_index_item', function mapItem() {
 $('form[data-action]').on('submit', function submit(e) {
   e.preventDefault();
   const form = $(this).get(0);
-  // Somehow submit is fired twice. Only once is the executionDate/dagRunId valid
   if (dagRunId || executionDate) {
     if (form.dag_run_id) {
       form.dag_run_id.value = dagRunId;
@@ -354,6 +350,9 @@ $('form[data-action]').on('submit', function submit(e) {
     if (form.map_index) {
       form.map_index.value = mapIndex === undefined ? '' : mapIndex;
     }
+    if (form.map_indexes) {
+      form.map_indexes.value = mapIndex === undefined ? '' : mapIndex;
+    }
     form.action = $(this).data('action');
     form.submit();
   }
diff --git a/airflow/www/templates/airflow/dag.html b/airflow/www/templates/airflow/dag.html
index 1dafcb27f7..8259c7045c 100644
--- a/airflow/www/templates/airflow/dag.html
+++ b/airflow/www/templates/airflow/dag.html
@@ -310,6 +310,7 @@
             <input type="hidden" name="dag_id" value="{{ dag.dag_id }}">
             <input type="hidden" name="task_id">
             <input type="hidden" name="execution_date">
+            <input type="hidden" name="map_indexes">
             <input type="hidden" name="origin" value="{{ request.base_url }}">
             <div class="row">
               <span class="btn-group col-xs-12 col-sm-9 task-instance-modal-column" data-toggle="buttons">
@@ -351,6 +352,7 @@
             <input type="hidden" name="dag_id" value="{{ dag.dag_id }}">
             <input type="hidden" name="task_id">
             <input type="hidden" name="dag_run_id">
+            <input type="hidden" name="map_indexes">
             <input type="hidden" name="origin" value="{{ request.base_url }}">
             <input type="hidden" name="state" value="failed">
             <div class="row">
@@ -384,6 +386,7 @@
             <input type="hidden" name="dag_id" value="{{ dag.dag_id }}">
             <input type="hidden" name="task_id">
             <input type="hidden" name="dag_run_id">
+            <input type="hidden" name="map_indexes">
             <input type="hidden" name="origin" value="{{ request.base_url }}">
             <input type="hidden" name="state" value="success">
             <div class="row">


[airflow] 14/19: fixup! Introduce tuple_().in_() shim for MSSQL compat

Posted by bb...@apache.org.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 3d25f2ebba5fe55195f6148dffb61527d9480416
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Wed Apr 20 10:37:53 2022 +0800

    fixup! Introduce tuple_().in_() shim for MSSQL compat
---
 airflow/utils/sqlalchemy.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index 5c36d826b2..2838a30d60 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -24,7 +24,7 @@ from typing import Any, Dict, Iterable, Tuple
 
 import pendulum
 from dateutil import relativedelta
-from sqlalchemy import event, nullsfirst, tuple_
+from sqlalchemy import event, false, nullsfirst, tuple_
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql import ColumnElement
@@ -339,4 +339,7 @@ def tuple_in_condition(
     """
     if settings.engine.dialect.name != "mssql":
         return tuple_(*columns).in_(collection)
-    return or_(*(and_(*(c == v for c, v in zip(columns, values))) for values in collection))
+    clauses = [and_(*(c == v for c, v in zip(columns, values))) for values in collection]
+    if not clauses:
+        return false()
+    return or_(*clauses)