You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2023/08/08 17:43:28 UTC

[airflow] 04/37: Replace State by TaskInstanceState in Airflow executors (#32627)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-7-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 097d2bed4e372e4f4cc86b14c6b8a47dd8d65902
Author: Hussein Awala <hu...@awala.fr>
AuthorDate: Tue Aug 8 14:43:36 2023 +0200

    Replace State by TaskInstanceState in Airflow executors (#32627)
    
    * Replace State by TaskInstanceState in Airflow executors
    
    * Change state type in change_state method, KubernetesResultsType and KubernetesWatchType to TaskInstanceState
    
    * Fix change_state annotation in CeleryExecutor
    
    ---------
    
    Co-authored-by: Tzu-ping Chung <ur...@gmail.com>
    (cherry picked from commit 9556d6d5f611428ac8a3a5891647b720d4498ace)
---
 airflow/executors/base_executor.py                 |  8 +++----
 airflow/executors/debug_executor.py                | 28 +++++++++++-----------
 airflow/executors/sequential_executor.py           |  6 ++---
 .../providers/celery/executors/celery_executor.py  |  8 +++----
 .../kubernetes/executors/kubernetes_executor.py    |  9 +++----
 .../executors/kubernetes_executor_types.py         |  5 ++--
 .../executors/kubernetes_executor_utils.py         | 14 +++++++----
 7 files changed, 43 insertions(+), 35 deletions(-)

diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index 999125afe7..10aebbeb3d 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -32,7 +32,7 @@ from airflow.configuration import conf
 from airflow.exceptions import RemovedInAirflow3Warning
 from airflow.stats import Stats
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 PARALLELISM: int = conf.getint("core", "PARALLELISM")
 
@@ -295,7 +295,7 @@ class BaseExecutor(LoggingMixin):
             self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)
             self.running.add(key)
 
-    def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
+    def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None:
         """
         Changes state of the task.
 
@@ -317,7 +317,7 @@ class BaseExecutor(LoggingMixin):
         :param info: Executor information for the task instance
         :param key: Unique key for the task instance
         """
-        self.change_state(key, State.FAILED, info)
+        self.change_state(key, TaskInstanceState.FAILED, info)
 
     def success(self, key: TaskInstanceKey, info=None) -> None:
         """
@@ -326,7 +326,7 @@ class BaseExecutor(LoggingMixin):
         :param info: Executor information for the task instance
         :param key: Unique key for the task instance
         """
-        self.change_state(key, State.SUCCESS, info)
+        self.change_state(key, TaskInstanceState.SUCCESS, info)
 
     def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey, EventBufferValueType]:
         """
diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py
index ca23b09a67..8a46d6cda0 100644
--- a/airflow/executors/debug_executor.py
+++ b/airflow/executors/debug_executor.py
@@ -29,7 +29,7 @@ import time
 from typing import TYPE_CHECKING, Any
 
 from airflow.executors.base_executor import BaseExecutor
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from airflow.models.taskinstance import TaskInstance
@@ -68,15 +68,15 @@ class DebugExecutor(BaseExecutor):
         while self.tasks_to_run:
             ti = self.tasks_to_run.pop(0)
             if self.fail_fast and not task_succeeded:
-                self.log.info("Setting %s to %s", ti.key, State.UPSTREAM_FAILED)
-                ti.set_state(State.UPSTREAM_FAILED)
-                self.change_state(ti.key, State.UPSTREAM_FAILED)
+                self.log.info("Setting %s to %s", ti.key, TaskInstanceState.UPSTREAM_FAILED)
+                ti.set_state(TaskInstanceState.UPSTREAM_FAILED)
+                self.change_state(ti.key, TaskInstanceState.UPSTREAM_FAILED)
                 continue
 
             if self._terminated.is_set():
-                self.log.info("Executor is terminated! Stopping %s to %s", ti.key, State.FAILED)
-                ti.set_state(State.FAILED)
-                self.change_state(ti.key, State.FAILED)
+                self.log.info("Executor is terminated! Stopping %s to %s", ti.key, TaskInstanceState.FAILED)
+                ti.set_state(TaskInstanceState.FAILED)
+                self.change_state(ti.key, TaskInstanceState.FAILED)
                 continue
 
             task_succeeded = self._run_task(ti)
@@ -87,11 +87,11 @@ class DebugExecutor(BaseExecutor):
         try:
             params = self.tasks_params.pop(ti.key, {})
             ti.run(job_id=ti.job_id, **params)
-            self.change_state(key, State.SUCCESS)
+            self.change_state(key, TaskInstanceState.SUCCESS)
             return True
         except Exception as e:
-            ti.set_state(State.FAILED)
-            self.change_state(key, State.FAILED)
+            ti.set_state(TaskInstanceState.FAILED)
+            self.change_state(key, TaskInstanceState.FAILED)
             self.log.exception("Failed to execute task: %s.", str(e))
             return False
 
@@ -148,14 +148,14 @@ class DebugExecutor(BaseExecutor):
     def end(self) -> None:
         """Set states of queued tasks to UPSTREAM_FAILED marking them as not executed."""
         for ti in self.tasks_to_run:
-            self.log.info("Setting %s to %s", ti.key, State.UPSTREAM_FAILED)
-            ti.set_state(State.UPSTREAM_FAILED)
-            self.change_state(ti.key, State.UPSTREAM_FAILED)
+            self.log.info("Setting %s to %s", ti.key, TaskInstanceState.UPSTREAM_FAILED)
+            ti.set_state(TaskInstanceState.UPSTREAM_FAILED)
+            self.change_state(ti.key, TaskInstanceState.UPSTREAM_FAILED)
 
     def terminate(self) -> None:
         self._terminated.set()
 
-    def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
+    def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None:
         self.log.debug("Popping %s from executor task queue.", key)
         self.running.remove(key)
         self.event_buffer[key] = state, info
diff --git a/airflow/executors/sequential_executor.py b/airflow/executors/sequential_executor.py
index 28f88c6b87..2715edad6e 100644
--- a/airflow/executors/sequential_executor.py
+++ b/airflow/executors/sequential_executor.py
@@ -28,7 +28,7 @@ import subprocess
 from typing import TYPE_CHECKING, Any
 
 from airflow.executors.base_executor import BaseExecutor
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from airflow.executors.base_executor import CommandType
@@ -75,9 +75,9 @@ class SequentialExecutor(BaseExecutor):
 
             try:
                 subprocess.check_call(command, close_fds=True)
-                self.change_state(key, State.SUCCESS)
+                self.change_state(key, TaskInstanceState.SUCCESS)
             except subprocess.CalledProcessError as e:
-                self.change_state(key, State.FAILED)
+                self.change_state(key, TaskInstanceState.FAILED)
                 self.log.error("Failed to execute task %s.", str(e))
 
         self.commands_to_run = []
diff --git a/airflow/providers/celery/executors/celery_executor.py b/airflow/providers/celery/executors/celery_executor.py
index 51287f3c1b..4708ef2137 100644
--- a/airflow/providers/celery/executors/celery_executor.py
+++ b/airflow/providers/celery/executors/celery_executor.py
@@ -74,7 +74,7 @@ from airflow.configuration import conf
 from airflow.exceptions import AirflowTaskTimeout
 from airflow.executors.base_executor import BaseExecutor
 from airflow.stats import Stats
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 log = logging.getLogger(__name__)
 
@@ -299,7 +299,7 @@ class CeleryExecutor(BaseExecutor):
             self.task_publish_retries.pop(key, None)
             if isinstance(result, ExceptionWithTraceback):
                 self.log.error(CELERY_SEND_ERR_MSG_HEADER + ": %s\n%s\n", result.exception, result.traceback)
-                self.event_buffer[key] = (State.FAILED, None)
+                self.event_buffer[key] = (TaskInstanceState.FAILED, None)
             elif result is not None:
                 result.backend = cached_celery_backend
                 self.running.add(key)
@@ -308,7 +308,7 @@ class CeleryExecutor(BaseExecutor):
                 # Store the Celery task_id in the event buffer. This will get "overwritten" if the task
                 # has another event, but that is fine, because the only other events are success/failed at
                 # which point we don't need the ID anymore anyway
-                self.event_buffer[key] = (State.QUEUED, result.task_id)
+                self.event_buffer[key] = (TaskInstanceState.QUEUED, result.task_id)
 
                 # If the task runs _really quickly_ we may already have a result!
                 self.update_task_state(key, result.state, getattr(result, "info", None))
@@ -355,7 +355,7 @@ class CeleryExecutor(BaseExecutor):
             if state:
                 self.update_task_state(key, state, info)
 
-    def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
+    def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None:
         super().change_state(key, state, info)
         self.tasks.pop(key, None)
 
diff --git a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
index a5aa36d981..051686c8d5 100644
--- a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
+++ b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
@@ -78,7 +78,7 @@ from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import annota
 from airflow.utils.event_scheduler import EventScheduler
 from airflow.utils.log.logging_mixin import remove_escape_codes
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import State, TaskInstanceState
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from kubernetes import client
@@ -425,7 +425,7 @@ class KubernetesExecutor(BaseExecutor):
     def _change_state(
         self,
         key: TaskInstanceKey,
-        state: str | None,
+        state: TaskInstanceState | None,
         pod_name: str,
         namespace: str,
         session: Session = NEW_SESSION,
@@ -433,12 +433,12 @@ class KubernetesExecutor(BaseExecutor):
         if TYPE_CHECKING:
             assert self.kube_scheduler
 
-        if state == State.RUNNING:
+        if state == TaskInstanceState.RUNNING:
             self.event_buffer[key] = state, None
             return
 
         if self.kube_config.delete_worker_pods:
-            if state != State.FAILED or self.kube_config.delete_worker_pods_on_failure:
+            if state != TaskInstanceState.FAILED or self.kube_config.delete_worker_pods_on_failure:
                 self.kube_scheduler.delete_pod(pod_name=pod_name, namespace=namespace)
                 self.log.info("Deleted pod: %s in namespace %s", str(key), str(namespace))
         else:
@@ -455,6 +455,7 @@ class KubernetesExecutor(BaseExecutor):
             from airflow.models.taskinstance import TaskInstance
 
             state = session.query(TaskInstance.state).filter(TaskInstance.filter_for_tis([key])).scalar()
+            state = TaskInstanceState(state)
 
         self.event_buffer[key] = state, None
 
diff --git a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py
index a13cd35f8d..80b8f1de72 100644
--- a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py
+++ b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py
@@ -21,15 +21,16 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
 if TYPE_CHECKING:
     from airflow.executors.base_executor import CommandType
     from airflow.models.taskinstance import TaskInstanceKey
+    from airflow.utils.state import TaskInstanceState
 
     # TaskInstance key, command, configuration, pod_template_file
     KubernetesJobType = Tuple[TaskInstanceKey, CommandType, Any, Optional[str]]
 
     # key, pod state, pod_name, namespace, resource_version
-    KubernetesResultsType = Tuple[TaskInstanceKey, Optional[str], str, str, str]
+    KubernetesResultsType = Tuple[TaskInstanceKey, Optional[TaskInstanceState], str, str, str]
 
     # pod_name, namespace, pod state, annotations, resource_version
-    KubernetesWatchType = Tuple[str, str, Optional[str], Dict[str, str], str]
+    KubernetesWatchType = Tuple[str, str, Optional[TaskInstanceState], Dict[str, str], str]
 
 ALL_NAMESPACES = "ALL_NAMESPACES"
 POD_EXECUTOR_DONE_KEY = "airflow_executor_done"
diff --git a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py
index c1ee9d1ebe..b19d88eb49 100644
--- a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py
+++ b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py
@@ -36,7 +36,7 @@ from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import (
 )
 from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 try:
     from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_types import (
@@ -223,12 +223,16 @@ class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin):
             # since kube server have received request to delete pod set TI state failed
             if event["type"] == "DELETED" and pod.metadata.deletion_timestamp:
                 self.log.info("Event: Failed to start pod %s, annotations: %s", pod_name, annotations_string)
-                self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version))
+                self.watcher_queue.put(
+                    (pod_name, namespace, TaskInstanceState.FAILED, annotations, resource_version)
+                )
             else:
                 self.log.debug("Event: %s Pending, annotations: %s", pod_name, annotations_string)
         elif status == "Failed":
             self.log.error("Event: %s Failed, annotations: %s", pod_name, annotations_string)
-            self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version))
+            self.watcher_queue.put(
+                (pod_name, namespace, TaskInstanceState.FAILED, annotations, resource_version)
+            )
         elif status == "Succeeded":
             # We get multiple events once the pod hits a terminal state, and we only want to
             # send it along to the scheduler once.
@@ -256,7 +260,9 @@ class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin):
                     pod_name,
                     annotations_string,
                 )
-                self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version))
+                self.watcher_queue.put(
+                    (pod_name, namespace, TaskInstanceState.FAILED, annotations, resource_version)
+                )
             else:
                 self.log.info("Event: %s is Running, annotations: %s", pod_name, annotations_string)
         else: