You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/11/20 16:10:12 UTC

[airflow] branch main updated: Fix task instance api cannot list task instances with None state (#19487)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f636060  Fix task instance api cannot list task instances with None state (#19487)
f636060 is described below

commit f636060fd7b5eb8facd1acb10a731d4e03bc864a
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Sat Nov 20 17:09:33 2021 +0100

    Fix task instance api cannot list task instances with None state (#19487)
    
    * Fix task instance api cannot list task instances with None state
    
    A task instance's state can be None, and the API accepts the string `none` to denote a null state.
    
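    This PR fixes the issue by converting `none` to None and rewriting the state filter as an
    OR of equality checks (which compile to `IS NULL` for None), so that task instances with a
    NULL state are matched; a SQL `IN (...)` list never matches NULL.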
---
 .../endpoints/task_instance_endpoint.py            | 20 ++++++++++---
 airflow/api_connexion/openapi/v1.yaml              |  2 +-
 .../api_connexion/schemas/task_instance_schema.py  |  2 +-
 .../endpoints/test_task_instance_endpoint.py       | 35 +++++++++++++++++++---
 4 files changed, 49 insertions(+), 10 deletions(-)

diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index c4f464a..60f847d 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -20,6 +20,7 @@ from flask import current_app, request
 from marshmallow import ValidationError
 from sqlalchemy import and_, func
 from sqlalchemy.orm.exc import NoResultFound
+from sqlalchemy.sql.expression import or_
 
 from airflow.api_connexion import security
 from airflow.api_connexion.exceptions import BadRequest, NotFound
@@ -73,9 +74,16 @@ def get_task_instance(dag_id: str, dag_run_id: str, task_id: str, session=None):
     return task_instance_schema.dump(task_instance)
 
 
+def _convert_state(states):
+    if not states:
+        return None
+    return [State.NONE if s == "none" else s for s in states]
+
+
 def _apply_array_filter(query, key, values):
     if values is not None:
-        query = query.filter(key.in_(values))
+        cond = ((key == v) for v in values)
+        query = query.filter(or_(*cond))
     return query
 
 
@@ -118,13 +126,16 @@ def get_task_instances(
     end_date_lte: Optional[str] = None,
     duration_gte: Optional[float] = None,
     duration_lte: Optional[float] = None,
-    state: Optional[str] = None,
+    state: Optional[List[str]] = None,
     pool: Optional[List[str]] = None,
     queue: Optional[List[str]] = None,
     offset: Optional[int] = None,
     session=None,
 ):
     """Get list of task instances."""
+    # Because state can be 'none'
+    states = _convert_state(state)
+
     base_query = session.query(TI).join(TI.dag_run)
 
     if dag_id != "~":
@@ -141,7 +152,7 @@ def get_task_instances(
     )
     base_query = _apply_range_filter(base_query, key=TI.end_date, value_range=(end_date_gte, end_date_lte))
     base_query = _apply_range_filter(base_query, key=TI.duration, value_range=(duration_gte, duration_lte))
-    base_query = _apply_array_filter(base_query, key=TI.state, values=state)
+    base_query = _apply_array_filter(base_query, key=TI.state, values=states)
     base_query = _apply_array_filter(base_query, key=TI.pool, values=pool)
     base_query = _apply_array_filter(base_query, key=TI.queue, values=queue)
 
@@ -180,6 +191,7 @@ def get_task_instances_batch(session=None):
         data = task_instance_batch_form.load(body)
     except ValidationError as err:
         raise BadRequest(detail=str(err.messages))
+    states = _convert_state(data['state'])
     base_query = session.query(TI).join(TI.dag_run)
 
     base_query = _apply_array_filter(base_query, key=TI.dag_id, values=data["dag_ids"])
@@ -199,7 +211,7 @@ def get_task_instances_batch(session=None):
     base_query = _apply_range_filter(
         base_query, key=TI.duration, value_range=(data["duration_gte"], data["duration_lte"])
     )
-    base_query = _apply_array_filter(base_query, key=TI.state, values=data["state"])
+    base_query = _apply_array_filter(base_query, key=TI.state, values=states)
     base_query = _apply_array_filter(base_query, key=TI.pool, values=data["pool"])
     base_query = _apply_array_filter(base_query, key=TI.queue, values=data["queue"])
 
diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml
index a4d75c5..5ff3937 100644
--- a/airflow/api_connexion/openapi/v1.yaml
+++ b/airflow/api_connexion/openapi/v1.yaml
@@ -3237,7 +3237,7 @@ components:
         state:
           type: array
           items:
-            type: string
+            $ref: '#/components/schemas/TaskState'
           description:
             The value can be repeated to retrieve multiple matching values (OR condition).
         pool:
diff --git a/airflow/api_connexion/schemas/task_instance_schema.py b/airflow/api_connexion/schemas/task_instance_schema.py
index 956246b..76d284d 100644
--- a/airflow/api_connexion/schemas/task_instance_schema.py
+++ b/airflow/api_connexion/schemas/task_instance_schema.py
@@ -43,7 +43,7 @@ class TaskInstanceSchema(SQLAlchemySchema):
     start_date = auto_field()
     end_date = auto_field()
     duration = auto_field()
-    state = auto_field()
+    state = TaskInstanceStateField()
     _try_number = auto_field(data_key="try_number")
     max_tries = auto_field()
     hostname = auto_field()
diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
index 2a2e0f1..4e887c5 100644
--- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
@@ -370,13 +370,26 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                     {"state": State.RUNNING},
                     {"state": State.QUEUED},
                     {"state": State.SUCCESS},
+                    {"state": State.NONE},
                 ],
                 False,
                 (
                     "/api/v1/dags/example_python_operator/dagRuns/"
-                    "TEST_DAG_RUN_ID/taskInstances?state=running,queued"
+                    "TEST_DAG_RUN_ID/taskInstances?state=running,queued,none"
                 ),
-                2,
+                3,
+            ),
+            (
+                "test null states with no filter",
+                [
+                    {"state": State.NONE},
+                    {"state": State.NONE},
+                    {"state": State.NONE},
+                    {"state": State.NONE},
+                ],
+                False,
+                ("/api/v1/dags/example_python_operator/dagRuns/" "TEST_DAG_RUN_ID/taskInstances"),
+                4,
             ),
             (
                 "test pool filter",
@@ -503,10 +516,24 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
                     {"state": State.RUNNING},
                     {"state": State.QUEUED},
                     {"state": State.SUCCESS},
+                    {"state": State.NONE},
                 ],
                 False,
-                {"state": ["running", "queued"]},
-                2,
+                {"state": ["running", "queued", "none"]},
+                3,
+                "test_task_read_only",
+            ),
+            (
+                "test dag with null states",
+                [
+                    {"state": State.NONE},
+                    {"state": State.NONE},
+                    {"state": State.NONE},
+                    {"state": State.NONE},
+                ],
+                False,
+                {},
+                4,
                 "test_task_read_only",
             ),
             (