You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2023/08/08 17:43:28 UTC
[airflow] 04/37: Replace State by TaskInstanceState in Airflow executors (#32627)
This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch v2-7-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 097d2bed4e372e4f4cc86b14c6b8a47dd8d65902
Author: Hussein Awala <hu...@awala.fr>
AuthorDate: Tue Aug 8 14:43:36 2023 +0200
Replace State by TaskInstanceState in Airflow executors (#32627)
* Replace State by TaskInstanceState in Airflow executors
* change state type in change_state method, KubernetesResultsType and KubernetesWatchType to TaskInstanceState
* Fix change_state annotation in CeleryExecutor
---------
Co-authored-by: Tzu-ping Chung <ur...@gmail.com>
(cherry picked from commit 9556d6d5f611428ac8a3a5891647b720d4498ace)
---
airflow/executors/base_executor.py | 8 +++----
airflow/executors/debug_executor.py | 28 +++++++++++-----------
airflow/executors/sequential_executor.py | 6 ++---
.../providers/celery/executors/celery_executor.py | 8 +++----
.../kubernetes/executors/kubernetes_executor.py | 9 +++----
.../executors/kubernetes_executor_types.py | 5 ++--
.../executors/kubernetes_executor_utils.py | 14 +++++++----
7 files changed, 43 insertions(+), 35 deletions(-)
diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index 999125afe7..10aebbeb3d 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -32,7 +32,7 @@ from airflow.configuration import conf
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
PARALLELISM: int = conf.getint("core", "PARALLELISM")
@@ -295,7 +295,7 @@ class BaseExecutor(LoggingMixin):
self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)
self.running.add(key)
- def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
+ def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None:
"""
Changes state of the task.
@@ -317,7 +317,7 @@ class BaseExecutor(LoggingMixin):
:param info: Executor information for the task instance
:param key: Unique key for the task instance
"""
- self.change_state(key, State.FAILED, info)
+ self.change_state(key, TaskInstanceState.FAILED, info)
def success(self, key: TaskInstanceKey, info=None) -> None:
"""
@@ -326,7 +326,7 @@ class BaseExecutor(LoggingMixin):
:param info: Executor information for the task instance
:param key: Unique key for the task instance
"""
- self.change_state(key, State.SUCCESS, info)
+ self.change_state(key, TaskInstanceState.SUCCESS, info)
def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey, EventBufferValueType]:
"""
diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py
index ca23b09a67..8a46d6cda0 100644
--- a/airflow/executors/debug_executor.py
+++ b/airflow/executors/debug_executor.py
@@ -29,7 +29,7 @@ import time
from typing import TYPE_CHECKING, Any
from airflow.executors.base_executor import BaseExecutor
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstance
@@ -68,15 +68,15 @@ class DebugExecutor(BaseExecutor):
while self.tasks_to_run:
ti = self.tasks_to_run.pop(0)
if self.fail_fast and not task_succeeded:
- self.log.info("Setting %s to %s", ti.key, State.UPSTREAM_FAILED)
- ti.set_state(State.UPSTREAM_FAILED)
- self.change_state(ti.key, State.UPSTREAM_FAILED)
+ self.log.info("Setting %s to %s", ti.key, TaskInstanceState.UPSTREAM_FAILED)
+ ti.set_state(TaskInstanceState.UPSTREAM_FAILED)
+ self.change_state(ti.key, TaskInstanceState.UPSTREAM_FAILED)
continue
if self._terminated.is_set():
- self.log.info("Executor is terminated! Stopping %s to %s", ti.key, State.FAILED)
- ti.set_state(State.FAILED)
- self.change_state(ti.key, State.FAILED)
+ self.log.info("Executor is terminated! Stopping %s to %s", ti.key, TaskInstanceState.FAILED)
+ ti.set_state(TaskInstanceState.FAILED)
+ self.change_state(ti.key, TaskInstanceState.FAILED)
continue
task_succeeded = self._run_task(ti)
@@ -87,11 +87,11 @@ class DebugExecutor(BaseExecutor):
try:
params = self.tasks_params.pop(ti.key, {})
ti.run(job_id=ti.job_id, **params)
- self.change_state(key, State.SUCCESS)
+ self.change_state(key, TaskInstanceState.SUCCESS)
return True
except Exception as e:
- ti.set_state(State.FAILED)
- self.change_state(key, State.FAILED)
+ ti.set_state(TaskInstanceState.FAILED)
+ self.change_state(key, TaskInstanceState.FAILED)
self.log.exception("Failed to execute task: %s.", str(e))
return False
@@ -148,14 +148,14 @@ class DebugExecutor(BaseExecutor):
def end(self) -> None:
"""Set states of queued tasks to UPSTREAM_FAILED marking them as not executed."""
for ti in self.tasks_to_run:
- self.log.info("Setting %s to %s", ti.key, State.UPSTREAM_FAILED)
- ti.set_state(State.UPSTREAM_FAILED)
- self.change_state(ti.key, State.UPSTREAM_FAILED)
+ self.log.info("Setting %s to %s", ti.key, TaskInstanceState.UPSTREAM_FAILED)
+ ti.set_state(TaskInstanceState.UPSTREAM_FAILED)
+ self.change_state(ti.key, TaskInstanceState.UPSTREAM_FAILED)
def terminate(self) -> None:
self._terminated.set()
- def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
+ def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None:
self.log.debug("Popping %s from executor task queue.", key)
self.running.remove(key)
self.event_buffer[key] = state, info
diff --git a/airflow/executors/sequential_executor.py b/airflow/executors/sequential_executor.py
index 28f88c6b87..2715edad6e 100644
--- a/airflow/executors/sequential_executor.py
+++ b/airflow/executors/sequential_executor.py
@@ -28,7 +28,7 @@ import subprocess
from typing import TYPE_CHECKING, Any
from airflow.executors.base_executor import BaseExecutor
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
from airflow.executors.base_executor import CommandType
@@ -75,9 +75,9 @@ class SequentialExecutor(BaseExecutor):
try:
subprocess.check_call(command, close_fds=True)
- self.change_state(key, State.SUCCESS)
+ self.change_state(key, TaskInstanceState.SUCCESS)
except subprocess.CalledProcessError as e:
- self.change_state(key, State.FAILED)
+ self.change_state(key, TaskInstanceState.FAILED)
self.log.error("Failed to execute task %s.", str(e))
self.commands_to_run = []
diff --git a/airflow/providers/celery/executors/celery_executor.py b/airflow/providers/celery/executors/celery_executor.py
index 51287f3c1b..4708ef2137 100644
--- a/airflow/providers/celery/executors/celery_executor.py
+++ b/airflow/providers/celery/executors/celery_executor.py
@@ -74,7 +74,7 @@ from airflow.configuration import conf
from airflow.exceptions import AirflowTaskTimeout
from airflow.executors.base_executor import BaseExecutor
from airflow.stats import Stats
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
log = logging.getLogger(__name__)
@@ -299,7 +299,7 @@ class CeleryExecutor(BaseExecutor):
self.task_publish_retries.pop(key, None)
if isinstance(result, ExceptionWithTraceback):
self.log.error(CELERY_SEND_ERR_MSG_HEADER + ": %s\n%s\n", result.exception, result.traceback)
- self.event_buffer[key] = (State.FAILED, None)
+ self.event_buffer[key] = (TaskInstanceState.FAILED, None)
elif result is not None:
result.backend = cached_celery_backend
self.running.add(key)
@@ -308,7 +308,7 @@ class CeleryExecutor(BaseExecutor):
# Store the Celery task_id in the event buffer. This will get "overwritten" if the task
# has another event, but that is fine, because the only other events are success/failed at
# which point we don't need the ID anymore anyway
- self.event_buffer[key] = (State.QUEUED, result.task_id)
+ self.event_buffer[key] = (TaskInstanceState.QUEUED, result.task_id)
# If the task runs _really quickly_ we may already have a result!
self.update_task_state(key, result.state, getattr(result, "info", None))
@@ -355,7 +355,7 @@ class CeleryExecutor(BaseExecutor):
if state:
self.update_task_state(key, state, info)
- def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
+ def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None:
super().change_state(key, state, info)
self.tasks.pop(key, None)
diff --git a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
index a5aa36d981..051686c8d5 100644
--- a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
+++ b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
@@ -78,7 +78,7 @@ from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import annota
from airflow.utils.event_scheduler import EventScheduler
from airflow.utils.log.logging_mixin import remove_escape_codes
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import State, TaskInstanceState
+from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
from kubernetes import client
@@ -425,7 +425,7 @@ class KubernetesExecutor(BaseExecutor):
def _change_state(
self,
key: TaskInstanceKey,
- state: str | None,
+ state: TaskInstanceState | None,
pod_name: str,
namespace: str,
session: Session = NEW_SESSION,
@@ -433,12 +433,12 @@ class KubernetesExecutor(BaseExecutor):
if TYPE_CHECKING:
assert self.kube_scheduler
- if state == State.RUNNING:
+ if state == TaskInstanceState.RUNNING:
self.event_buffer[key] = state, None
return
if self.kube_config.delete_worker_pods:
- if state != State.FAILED or self.kube_config.delete_worker_pods_on_failure:
+ if state != TaskInstanceState.FAILED or self.kube_config.delete_worker_pods_on_failure:
self.kube_scheduler.delete_pod(pod_name=pod_name, namespace=namespace)
self.log.info("Deleted pod: %s in namespace %s", str(key), str(namespace))
else:
@@ -455,6 +455,7 @@ class KubernetesExecutor(BaseExecutor):
from airflow.models.taskinstance import TaskInstance
state = session.query(TaskInstance.state).filter(TaskInstance.filter_for_tis([key])).scalar()
+ state = TaskInstanceState(state)
self.event_buffer[key] = state, None
diff --git a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py
index a13cd35f8d..80b8f1de72 100644
--- a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py
+++ b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py
@@ -21,15 +21,16 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
if TYPE_CHECKING:
from airflow.executors.base_executor import CommandType
from airflow.models.taskinstance import TaskInstanceKey
+ from airflow.utils.state import TaskInstanceState
# TaskInstance key, command, configuration, pod_template_file
KubernetesJobType = Tuple[TaskInstanceKey, CommandType, Any, Optional[str]]
# key, pod state, pod_name, namespace, resource_version
- KubernetesResultsType = Tuple[TaskInstanceKey, Optional[str], str, str, str]
+ KubernetesResultsType = Tuple[TaskInstanceKey, Optional[TaskInstanceState], str, str, str]
# pod_name, namespace, pod state, annotations, resource_version
- KubernetesWatchType = Tuple[str, str, Optional[str], Dict[str, str], str]
+ KubernetesWatchType = Tuple[str, str, Optional[TaskInstanceState], Dict[str, str], str]
ALL_NAMESPACES = "ALL_NAMESPACES"
POD_EXECUTOR_DONE_KEY = "airflow_executor_done"
diff --git a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py
index c1ee9d1ebe..b19d88eb49 100644
--- a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py
+++ b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py
@@ -36,7 +36,7 @@ from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import (
)
from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
try:
from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_types import (
@@ -223,12 +223,16 @@ class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin):
# since kube server have received request to delete pod set TI state failed
if event["type"] == "DELETED" and pod.metadata.deletion_timestamp:
self.log.info("Event: Failed to start pod %s, annotations: %s", pod_name, annotations_string)
- self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version))
+ self.watcher_queue.put(
+ (pod_name, namespace, TaskInstanceState.FAILED, annotations, resource_version)
+ )
else:
self.log.debug("Event: %s Pending, annotations: %s", pod_name, annotations_string)
elif status == "Failed":
self.log.error("Event: %s Failed, annotations: %s", pod_name, annotations_string)
- self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version))
+ self.watcher_queue.put(
+ (pod_name, namespace, TaskInstanceState.FAILED, annotations, resource_version)
+ )
elif status == "Succeeded":
# We get multiple events once the pod hits a terminal state, and we only want to
# send it along to the scheduler once.
@@ -256,7 +260,9 @@ class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin):
pod_name,
annotations_string,
)
- self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version))
+ self.watcher_queue.put(
+ (pod_name, namespace, TaskInstanceState.FAILED, annotations, resource_version)
+ )
else:
self.log.info("Event: %s is Running, annotations: %s", pod_name, annotations_string)
else: