You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/10/31 05:31:33 UTC

[airflow] branch main updated: Api endpoint update ti (#26165)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1e6f1d54c5 Api endpoint update ti (#26165)
1e6f1d54c5 is described below

commit 1e6f1d54c54e5dc50078216e23ba01560ebb133c
Author: Cedric Koffeto <38...@users.noreply.github.com>
AuthorDate: Mon Oct 31 06:31:26 2022 +0100

    Api endpoint update ti (#26165)
---
 .../endpoints/task_instance_endpoint.py            |  65 ++++++
 airflow/api_connexion/openapi/v1.yaml              |  78 +++++++
 .../api_connexion/schemas/task_instance_schema.py  |   8 +
 airflow/www/static/js/types/api-generated.ts       |  95 +++++++++
 .../endpoints/test_task_instance_endpoint.py       | 223 +++++++++++++++++++++
 5 files changed, 469 insertions(+)

diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index 85fe4d66cd..415e90abb4 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -33,10 +33,12 @@ from airflow.api_connexion.schemas.task_instance_schema import (
     TaskInstanceCollection,
     TaskInstanceReferenceCollection,
     clear_task_instance_form,
+    set_single_task_instance_state_form,
     set_task_instance_state_form,
     task_instance_batch_form,
     task_instance_collection_schema,
     task_instance_reference_collection_schema,
+    task_instance_reference_schema,
     task_instance_schema,
 )
 from airflow.api_connexion.types import APIResponse
@@ -545,3 +547,66 @@ def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION
         session=session,
     )
     return task_instance_reference_collection_schema.dump(TaskInstanceReferenceCollection(task_instances=tis))
+
+
+@security.requires_access(
+    [
+        (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
+        (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
+        (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE),
+    ],
+)
+@provide_session
+def patch_task_instance(
+    *, dag_id: str, dag_run_id: str, task_id: str, map_index: int = -1, session: Session = NEW_SESSION
+) -> APIResponse:
+    """Update the state of a task instance and return a reference to it (unchanged on dry_run)."""
+    body = get_json_request_dict()
+    try:
+        data = set_single_task_instance_state_form.load(body)
+    except ValidationError as err:
+        raise BadRequest(detail=str(err.messages))
+
+    dag = get_airflow_app().dag_bag.get_dag(dag_id)
+    if not dag:
+        raise NotFound("DAG not found", detail=f"DAG {dag_id!r} not found")
+
+    if not dag.has_task(task_id):
+        raise NotFound("Task not found", detail=f"Task {task_id!r} not found in DAG {dag_id!r}")
+
+    ti: TI | None = session.query(TI).get(
+        {'task_id': task_id, 'dag_id': dag_id, 'run_id': dag_run_id, 'map_index': map_index}
+    )
+    if not ti:
+        error_message = f"Task instance not found for task {task_id!r} on DAG run with ID {dag_run_id!r}"
+        raise NotFound(detail=error_message)
+
+    if not data["dry_run"]:
+        # set_task_instance_state returns a list of altered TIs (may be empty); fall back to the fetched TI.
+        altered = dag.set_task_instance_state(
+            task_id=task_id,
+            run_id=dag_run_id,
+            map_indexes=[map_index],
+            state=data["new_state"],
+            commit=True,
+            session=session,
+        )
+        ti = next(iter(altered), ti) if isinstance(altered, list) else altered or ti
+    return task_instance_reference_schema.dump(ti)
+
+
+@security.requires_access(
+    [
+        (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
+        (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
+        (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE),
+    ],
+)
+@provide_session
+def patch_mapped_task_instance(
+    *, dag_id: str, dag_run_id: str, task_id: str, map_index: int, session: Session = NEW_SESSION
+) -> APIResponse:
+    """Update the state of a mapped task instance by delegating to patch_task_instance with its map_index."""
+    return patch_task_instance(
+        dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, map_index=map_index, session=session
+    )
diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml
index 950d86d16e..2a100affed 100644
--- a/airflow/api_connexion/openapi/v1.yaml
+++ b/airflow/api_connexion/openapi/v1.yaml
@@ -1144,6 +1144,37 @@ paths:
         '404':
           $ref: '#/components/responses/NotFound'
 
+    patch:
+      summary: Updates the state of a task instance
+      description: >
+        Updates the state for single task instance.
+
+        *New in version 2.5.0*
+      x-openapi-router-controller: airflow.api_connexion.endpoints.task_instance_endpoint
+      operationId: patch_task_instance
+      tags: [TaskInstance]
+      requestBody:
+        description: Parameters of action
+        required: true
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/UpdateTaskInstance'
+      responses:
+        '200':
+          description: Success.
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/TaskInstanceReference'
+        '401':
+          $ref: '#/components/responses/Unauthenticated'
+        '403':
+          $ref: '#/components/responses/PermissionDenied'
+        '404':
+          $ref: '#/components/responses/NotFound'
+
+
   /dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}:
     parameters:
       - $ref: '#/components/parameters/DAGID'
@@ -1174,6 +1205,36 @@ paths:
         '404':
           $ref: '#/components/responses/NotFound'
 
+    patch:
+      summary: Updates the state of a mapped task instance
+      description: >
+        Updates the state for single mapped task instance.
+
+        *New in version 2.5.0*
+      x-openapi-router-controller: airflow.api_connexion.endpoints.task_instance_endpoint
+      operationId: patch_mapped_task_instance
+      tags: [TaskInstance]
+      requestBody:
+        description: Parameters of action
+        required: true
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/UpdateTaskInstance'
+      responses:
+        '200':
+          description: Success.
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/TaskInstanceReference'
+        '401':
+          $ref: '#/components/responses/Unauthenticated'
+        '403':
+          $ref: '#/components/responses/PermissionDenied'
+        '404':
+          $ref: '#/components/responses/NotFound'
+
   /dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/listMapped:
     parameters:
       - $ref: '#/components/parameters/DAGID'
@@ -3958,6 +4019,23 @@ components:
             - success
             - failed
 
+    UpdateTaskInstance:
+      type: object
+      properties:
+        dry_run:
+          description: |
+            If set, don't actually run this operation. The response will contain the task instance
+            planned to be affected, but won't be modified in any way.
+          type: boolean
+          default: false
+
+        new_state:
+          description: Expected new state.
+          type: string
+          enum:
+            - success
+            - failed
+
     ListDagRunsForm:
       type: object
       properties:
diff --git a/airflow/api_connexion/schemas/task_instance_schema.py b/airflow/api_connexion/schemas/task_instance_schema.py
index 259a3a78a9..33c3d55069 100644
--- a/airflow/api_connexion/schemas/task_instance_schema.py
+++ b/airflow/api_connexion/schemas/task_instance_schema.py
@@ -166,6 +166,13 @@ class SetTaskInstanceStateFormSchema(Schema):
             raise ValidationError("Exactly one of execution_date or dag_run_id must be provided")
 
 
+class SetSingleTaskInstanceStateFormSchema(Schema):
+    """Schema for the request body that updates the state of a single task instance"""
+
+    dry_run = fields.Boolean(load_default=False)  # same default as UpdateTaskInstance in v1.yaml
+    new_state = TaskInstanceStateField(required=True, validate=validate.OneOf([State.SUCCESS, State.FAILED]))
+
+
 class TaskInstanceReferenceSchema(Schema):
     """Schema for the task instance reference schema"""
 
@@ -192,5 +199,6 @@ task_instance_collection_schema = TaskInstanceCollectionSchema()
 task_instance_batch_form = TaskInstanceBatchFormSchema()
 clear_task_instance_form = ClearTaskInstanceFormSchema()
 set_task_instance_state_form = SetTaskInstanceStateFormSchema()
+set_single_task_instance_state_form = SetSingleTaskInstanceStateFormSchema()
 task_instance_reference_schema = TaskInstanceReferenceSchema()
 task_instance_reference_collection_schema = TaskInstanceReferenceCollectionSchema()
diff --git a/airflow/www/static/js/types/api-generated.ts b/airflow/www/static/js/types/api-generated.ts
index dd69121e83..4d08be848d 100644
--- a/airflow/www/static/js/types/api-generated.ts
+++ b/airflow/www/static/js/types/api-generated.ts
@@ -285,6 +285,11 @@ export interface paths {
   };
   "/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}": {
     get: operations["get_task_instance"];
+    /**
+     * Updates the state for single task instance.
+     * *New in version 2.5.0*
+     */
+    patch: operations["patch_task_instance"];
     parameters: {
       path: {
         /** The DAG ID. */
@@ -303,6 +308,11 @@ export interface paths {
      * *New in version 2.3.0*
      */
     get: operations["get_mapped_task_instance"];
+    /**
+     * Updates the state for single mapped task instance.
+     * *New in version 2.5.0*
+     */
+    patch: operations["patch_mapped_task_instance"];
     parameters: {
       path: {
         /** The DAG ID. */
@@ -1733,6 +1743,20 @@ export interface components {
        */
       new_state?: "success" | "failed";
     };
+    UpdateTaskInstance: {
+      /**
+       * @description If set, don't actually run this operation. The response will contain the task instance
+       * planned to be affected, but won't be modified in any way.
+       *
+       * @default false
+       */
+      dry_run?: boolean;
+      /**
+       * @description Expected new state.
+       * @enum {string}
+       */
+      new_state?: "success" | "failed";
+    };
     ListDagRunsForm: {
       /**
        * @description The name of the field to order the results by. Prefix a field name
@@ -3137,6 +3161,39 @@ export interface operations {
       404: components["responses"]["NotFound"];
     };
   };
+  /**
+   * Updates the state for single task instance.
+   * *New in version 2.5.0*
+   */
+  patch_task_instance: {
+    parameters: {
+      path: {
+        /** The DAG ID. */
+        dag_id: components["parameters"]["DAGID"];
+        /** The DAG run ID. */
+        dag_run_id: components["parameters"]["DAGRunID"];
+        /** The task ID. */
+        task_id: components["parameters"]["TaskID"];
+      };
+    };
+    responses: {
+      /** Success. */
+      200: {
+        content: {
+          "application/json": components["schemas"]["TaskInstanceReference"];
+        };
+      };
+      401: components["responses"]["Unauthenticated"];
+      403: components["responses"]["PermissionDenied"];
+      404: components["responses"]["NotFound"];
+    };
+    /** Parameters of action */
+    requestBody: {
+      content: {
+        "application/json": components["schemas"]["UpdateTaskInstance"];
+      };
+    };
+  };
   /**
    * Get details of a mapped task instance.
    *
@@ -3167,6 +3224,41 @@ export interface operations {
       404: components["responses"]["NotFound"];
     };
   };
+  /**
+   * Updates the state for single mapped task instance.
+   * *New in version 2.5.0*
+   */
+  patch_mapped_task_instance: {
+    parameters: {
+      path: {
+        /** The DAG ID. */
+        dag_id: components["parameters"]["DAGID"];
+        /** The DAG run ID. */
+        dag_run_id: components["parameters"]["DAGRunID"];
+        /** The task ID. */
+        task_id: components["parameters"]["TaskID"];
+        /** The map index. */
+        map_index: components["parameters"]["MapIndex"];
+      };
+    };
+    responses: {
+      /** Success. */
+      200: {
+        content: {
+          "application/json": components["schemas"]["TaskInstanceReference"];
+        };
+      };
+      401: components["responses"]["Unauthenticated"];
+      403: components["responses"]["PermissionDenied"];
+      404: components["responses"]["NotFound"];
+    };
+    /** Parameters of action */
+    requestBody: {
+      content: {
+        "application/json": components["schemas"]["UpdateTaskInstance"];
+      };
+    };
+  };
   /**
    * Get details of all mapped task instances.
    *
@@ -4201,6 +4293,7 @@ export type VersionInfo = CamelCasedPropertiesDeep<components['schemas']['Versio
 export type ClearDagRun = CamelCasedPropertiesDeep<components['schemas']['ClearDagRun']>;
 export type ClearTaskInstances = CamelCasedPropertiesDeep<components['schemas']['ClearTaskInstances']>;
 export type UpdateTaskInstancesState = CamelCasedPropertiesDeep<components['schemas']['UpdateTaskInstancesState']>;
+export type UpdateTaskInstance = CamelCasedPropertiesDeep<components['schemas']['UpdateTaskInstance']>;
 export type ListDagRunsForm = CamelCasedPropertiesDeep<components['schemas']['ListDagRunsForm']>;
 export type ListTaskInstanceForm = CamelCasedPropertiesDeep<components['schemas']['ListTaskInstanceForm']>;
 export type ScheduleInterval = CamelCasedPropertiesDeep<components['schemas']['ScheduleInterval']>;
@@ -4255,7 +4348,9 @@ export type DeletePoolVariables = CamelCasedPropertiesDeep<operations['delete_po
 export type PatchPoolVariables = CamelCasedPropertiesDeep<operations['patch_pool']['parameters']['path'] & operations['patch_pool']['parameters']['query'] & operations['patch_pool']['requestBody']['content']['application/json']>;
 export type GetTaskInstancesVariables = CamelCasedPropertiesDeep<operations['get_task_instances']['parameters']['path'] & operations['get_task_instances']['parameters']['query']>;
 export type GetTaskInstanceVariables = CamelCasedPropertiesDeep<operations['get_task_instance']['parameters']['path']>;
+export type PatchTaskInstanceVariables = CamelCasedPropertiesDeep<operations['patch_task_instance']['parameters']['path'] & operations['patch_task_instance']['requestBody']['content']['application/json']>;
 export type GetMappedTaskInstanceVariables = CamelCasedPropertiesDeep<operations['get_mapped_task_instance']['parameters']['path']>;
+export type PatchMappedTaskInstanceVariables = CamelCasedPropertiesDeep<operations['patch_mapped_task_instance']['parameters']['path'] & operations['patch_mapped_task_instance']['requestBody']['content']['application/json']>;
 export type GetMappedTaskInstancesVariables = CamelCasedPropertiesDeep<operations['get_mapped_task_instances']['parameters']['path'] & operations['get_mapped_task_instances']['parameters']['query']>;
 export type GetTaskInstancesBatchVariables = CamelCasedPropertiesDeep<operations['get_task_instances_batch']['requestBody']['content']['application/json']>;
 export type GetVariablesVariables = CamelCasedPropertiesDeep<operations['get_variables']['parameters']['query']>;
diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
index e1dd9a6b64..399cdc7540 100644
--- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
@@ -1582,3 +1582,226 @@ class TestPostSetTaskInstanceState(TestTaskInstanceEndpoint):
         )
         assert response.status_code == 400
         assert response.json["detail"] == expected
+
+
+class TestPatchTaskInstance(TestTaskInstanceEndpoint):
+    ENDPOINT_URL = (
+        "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context"
+    )
+
+    @mock.patch("airflow.models.dag.DAG.set_task_instance_state")
+    def test_should_call_mocked_api(self, mock_set_task_instance_state, session):
+        self.create_task_instances(session)
+
+        NEW_STATE = "failed"
+        mock_set_task_instance_state.return_value = session.query(TaskInstance).get(
+            {
+                "task_id": "print_the_context",
+                "dag_id": "example_python_operator",
+                "run_id": "TEST_DAG_RUN_ID",
+                "map_index": -1,
+            }
+        )
+        response = self.client.patch(
+            self.ENDPOINT_URL,
+            environ_overrides={"REMOTE_USER": "test"},
+            json={
+                "dry_run": False,
+                "new_state": NEW_STATE,
+            },
+        )
+        assert response.status_code == 200
+        assert response.json == {
+            "dag_id": "example_python_operator",
+            "dag_run_id": "TEST_DAG_RUN_ID",
+            "execution_date": "2020-01-01T00:00:00+00:00",
+            "task_id": "print_the_context",
+        }
+
+        mock_set_task_instance_state.assert_called_once_with(
+            task_id="print_the_context",
+            run_id="TEST_DAG_RUN_ID",
+            map_indexes=[-1],
+            state=NEW_STATE,
+            commit=True,
+            session=session,
+        )
+
+    @mock.patch("airflow.models.dag.DAG.set_task_instance_state")
+    def test_should_not_call_mocked_api_for_dry_run(self, mock_set_task_instance_state, session):
+        """With dry_run the TI reference is returned but set_task_instance_state is never called."""
+        self.create_task_instances(session)
+
+        NEW_STATE = "failed"
+        mock_set_task_instance_state.return_value = session.query(TaskInstance).get(
+            {
+                "task_id": "print_the_context",
+                "dag_id": "example_python_operator",
+                "run_id": "TEST_DAG_RUN_ID",
+                "map_index": -1,
+            }
+        )
+        response = self.client.patch(
+            self.ENDPOINT_URL,
+            environ_overrides={"REMOTE_USER": "test"},
+            json={
+                "dry_run": True,
+                "new_state": NEW_STATE,
+            },
+        )
+        assert response.status_code == 200
+        assert response.json == {
+            "dag_id": "example_python_operator",
+            "dag_run_id": "TEST_DAG_RUN_ID",
+            "execution_date": "2020-01-01T00:00:00+00:00",
+            "task_id": "print_the_context",
+        }
+
+        mock_set_task_instance_state.assert_not_called()
+
+    def test_should_update_task_instance_state(self, session):
+        self.create_task_instances(session)
+
+        NEW_STATE = "failed"
+        response = self.client.patch(
+            self.ENDPOINT_URL,
+            environ_overrides={"REMOTE_USER": "test"},
+            json={
+                "dry_run": False,
+                "new_state": NEW_STATE,
+            },
+        )
+        assert response.status_code == 200
+
+        response2 = self.client.get(
+            self.ENDPOINT_URL,
+            environ_overrides={"REMOTE_USER": "test"},
+            json={},
+        )
+        assert response2.status_code == 200
+        assert response2.json["state"] == NEW_STATE
+
+    def test_should_update_mapped_task_instance_state(self, session):
+        NEW_STATE = "failed"
+        map_index = 1
+
+        tis = self.create_task_instances(session)
+        ti = tis[0]
+        ti.map_index = map_index
+        rendered_fields = RTIF(ti, render_templates=False)
+        session.add(rendered_fields)
+        session.commit()
+
+        response = self.client.patch(
+            f"{self.ENDPOINT_URL}/{map_index}",
+            environ_overrides={"REMOTE_USER": "test"},
+            json={
+                "dry_run": False,
+                "new_state": NEW_STATE,
+            },
+        )
+        assert response.status_code == 200
+
+        response2 = self.client.get(
+            f"{self.ENDPOINT_URL}/{map_index}",
+            environ_overrides={"REMOTE_USER": "test"},
+            json={},
+        )
+        assert response2.status_code == 200
+        assert response2.json["state"] == NEW_STATE
+
+    @pytest.mark.parametrize(
+        "error, code, payload",
+        [
+            [
+                "Task instance not found for task 'print_the_context' on DAG run with ID 'TEST_DAG_RUN_ID'",
+                404,
+                {
+                    "dry_run": True,
+                    "new_state": "failed",
+                },
+            ]
+        ],
+    )
+    def test_should_handle_errors(self, error, code, payload, session):
+        response = self.client.patch(
+            self.ENDPOINT_URL,
+            environ_overrides={"REMOTE_USER": "test"},
+            json=payload,
+        )
+        assert response.status_code == code
+        assert response.json["detail"] == error
+
+    def test_should_raises_401_unauthenticated(self):
+        response = self.client.patch(
+            self.ENDPOINT_URL,
+            json={
+                "dry_run": False,
+                "new_state": "failed",
+            },
+        )
+        assert_401(response)
+
+    @parameterized.expand(["test_no_permissions", "test_dag_read_only", "test_task_read_only"])
+    def test_should_raise_403_forbidden(self, username):
+        response = self.client.patch(
+            self.ENDPOINT_URL,
+            environ_overrides={"REMOTE_USER": username},
+            json={
+                "dry_run": True,
+                "new_state": "failed",
+            },
+        )
+        assert response.status_code == 403
+
+    def test_should_raise_404_not_found_dag(self):
+        """A DAG id that is not in the DagBag yields 404 before any task instance lookup."""
+        response = self.client.patch(
+            "/api/v1/dags/INVALID_DAG/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context",
+            environ_overrides={"REMOTE_USER": "test"},
+            json={"dry_run": True, "new_state": "failed"},
+        )
+
+        assert response.status_code == 404
+        assert response.json["detail"] == "DAG 'INVALID_DAG' not found"
+
+    def test_should_raise_404_not_found_task(self):
+        """A task id that the DAG does not contain yields 404 before any task instance lookup."""
+        response = self.client.patch(
+            "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/INVALID_TASK",
+            environ_overrides={"REMOTE_USER": "test"},
+            json={"dry_run": True, "new_state": "failed"},
+        )
+
+        assert response.status_code == 404
+        assert response.json["detail"] == "Task 'INVALID_TASK' not found in DAG 'example_python_operator'"
+
+    @pytest.mark.parametrize(
+        "payload, expected",
+        [
+            (
+                {
+                    "dry_run": True,
+                    "new_state": "failede",
+                },
+                f"'failede' is not one of ['{State.SUCCESS}', '{State.FAILED}'] - 'new_state'",
+            ),
+            (
+                {
+                    "dry_run": True,
+                    "new_state": "queued",
+                },
+                f"'queued' is not one of ['{State.SUCCESS}', '{State.FAILED}'] - 'new_state'",
+            ),
+        ],
+    )
+    def test_should_raise_400_for_invalid_task_instance_state(self, payload, expected, session):
+        """A new_state other than success/failed is rejected with 400."""
+        self.create_task_instances(session)
+        response = self.client.patch(
+            self.ENDPOINT_URL,
+            environ_overrides={"REMOTE_USER": "test"},
+            json=payload,
+        )
+        assert response.status_code == 400
+        assert response.json["detail"] == expected