You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/11/20 16:10:12 UTC
[airflow] branch main updated: Fix task instance api cannot list task instances with None state (#19487)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f636060 Fix task instance api cannot list task instances with None state (#19487)
f636060 is described below
commit f636060fd7b5eb8facd1acb10a731d4e03bc864a
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Sat Nov 20 17:09:33 2021 +0100
Fix task instance api cannot list task instances with None state (#19487)
* Fix task instance api cannot list task instances with None state
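A task instance's state can be None (NULL in the database), and the API accepts the string `none` to request task instances with a null state.
Previously, filtering by `none` matched nothing: the value was passed through unchanged to an SQL `IN (...)` filter, which cannot match NULL rows.
This PR fixes the issue by converting `none` to None and building the state filter as an OR of equality comparisons, so SQLAlchemy emits `IS NULL` for the None state.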
---
.../endpoints/task_instance_endpoint.py | 20 ++++++++++---
airflow/api_connexion/openapi/v1.yaml | 2 +-
.../api_connexion/schemas/task_instance_schema.py | 2 +-
.../endpoints/test_task_instance_endpoint.py | 35 +++++++++++++++++++---
4 files changed, 49 insertions(+), 10 deletions(-)
diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index c4f464a..60f847d 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -20,6 +20,7 @@ from flask import current_app, request
from marshmallow import ValidationError
from sqlalchemy import and_, func
from sqlalchemy.orm.exc import NoResultFound
+from sqlalchemy.sql.expression import or_
from airflow.api_connexion import security
from airflow.api_connexion.exceptions import BadRequest, NotFound
@@ -73,9 +74,16 @@ def get_task_instance(dag_id: str, dag_run_id: str, task_id: str, session=None):
return task_instance_schema.dump(task_instance)
+def _convert_state(states):
+ if not states:
+ return None
+ return [State.NONE if s == "none" else s for s in states]
+
+
def _apply_array_filter(query, key, values):
if values is not None:
- query = query.filter(key.in_(values))
+ cond = ((key == v) for v in values)
+ query = query.filter(or_(*cond))
return query
@@ -118,13 +126,16 @@ def get_task_instances(
end_date_lte: Optional[str] = None,
duration_gte: Optional[float] = None,
duration_lte: Optional[float] = None,
- state: Optional[str] = None,
+ state: Optional[List[str]] = None,
pool: Optional[List[str]] = None,
queue: Optional[List[str]] = None,
offset: Optional[int] = None,
session=None,
):
"""Get list of task instances."""
+ # Because state can be 'none'
+ states = _convert_state(state)
+
base_query = session.query(TI).join(TI.dag_run)
if dag_id != "~":
@@ -141,7 +152,7 @@ def get_task_instances(
)
base_query = _apply_range_filter(base_query, key=TI.end_date, value_range=(end_date_gte, end_date_lte))
base_query = _apply_range_filter(base_query, key=TI.duration, value_range=(duration_gte, duration_lte))
- base_query = _apply_array_filter(base_query, key=TI.state, values=state)
+ base_query = _apply_array_filter(base_query, key=TI.state, values=states)
base_query = _apply_array_filter(base_query, key=TI.pool, values=pool)
base_query = _apply_array_filter(base_query, key=TI.queue, values=queue)
@@ -180,6 +191,7 @@ def get_task_instances_batch(session=None):
data = task_instance_batch_form.load(body)
except ValidationError as err:
raise BadRequest(detail=str(err.messages))
+ states = _convert_state(data['state'])
base_query = session.query(TI).join(TI.dag_run)
base_query = _apply_array_filter(base_query, key=TI.dag_id, values=data["dag_ids"])
@@ -199,7 +211,7 @@ def get_task_instances_batch(session=None):
base_query = _apply_range_filter(
base_query, key=TI.duration, value_range=(data["duration_gte"], data["duration_lte"])
)
- base_query = _apply_array_filter(base_query, key=TI.state, values=data["state"])
+ base_query = _apply_array_filter(base_query, key=TI.state, values=states)
base_query = _apply_array_filter(base_query, key=TI.pool, values=data["pool"])
base_query = _apply_array_filter(base_query, key=TI.queue, values=data["queue"])
diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml
index a4d75c5..5ff3937 100644
--- a/airflow/api_connexion/openapi/v1.yaml
+++ b/airflow/api_connexion/openapi/v1.yaml
@@ -3237,7 +3237,7 @@ components:
state:
type: array
items:
- type: string
+ $ref: '#/components/schemas/TaskState'
description:
The value can be repeated to retrieve multiple matching values (OR condition).
pool:
diff --git a/airflow/api_connexion/schemas/task_instance_schema.py b/airflow/api_connexion/schemas/task_instance_schema.py
index 956246b..76d284d 100644
--- a/airflow/api_connexion/schemas/task_instance_schema.py
+++ b/airflow/api_connexion/schemas/task_instance_schema.py
@@ -43,7 +43,7 @@ class TaskInstanceSchema(SQLAlchemySchema):
start_date = auto_field()
end_date = auto_field()
duration = auto_field()
- state = auto_field()
+ state = TaskInstanceStateField()
_try_number = auto_field(data_key="try_number")
max_tries = auto_field()
hostname = auto_field()
diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
index 2a2e0f1..4e887c5 100644
--- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
@@ -370,13 +370,26 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
{"state": State.RUNNING},
{"state": State.QUEUED},
{"state": State.SUCCESS},
+ {"state": State.NONE},
],
False,
(
"/api/v1/dags/example_python_operator/dagRuns/"
- "TEST_DAG_RUN_ID/taskInstances?state=running,queued"
+ "TEST_DAG_RUN_ID/taskInstances?state=running,queued,none"
),
- 2,
+ 3,
+ ),
+ (
+ "test null states with no filter",
+ [
+ {"state": State.NONE},
+ {"state": State.NONE},
+ {"state": State.NONE},
+ {"state": State.NONE},
+ ],
+ False,
+ ("/api/v1/dags/example_python_operator/dagRuns/" "TEST_DAG_RUN_ID/taskInstances"),
+ 4,
),
(
"test pool filter",
@@ -503,10 +516,24 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
{"state": State.RUNNING},
{"state": State.QUEUED},
{"state": State.SUCCESS},
+ {"state": State.NONE},
],
False,
- {"state": ["running", "queued"]},
- 2,
+ {"state": ["running", "queued", "none"]},
+ 3,
+ "test_task_read_only",
+ ),
+ (
+ "test dag with null states",
+ [
+ {"state": State.NONE},
+ {"state": State.NONE},
+ {"state": State.NONE},
+ {"state": State.NONE},
+ ],
+ False,
+ {},
+ 4,
"test_task_read_only",
),
(