You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by di...@apache.org on 2020/06/23 21:45:01 UTC
[airflow] 01/01: Monitor pods by labels instead of names (#6377)
This is an automated email from the ASF dual-hosted git repository.
dimberman pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 34eae6f7cb041fb3edd0ee65552649183df6e89a
Author: Daniel Imberman <da...@gmail.com>
AuthorDate: Sat May 16 14:13:58 2020 -0700
Monitor pods by labels instead of names (#6377)
* Monitor k8sPodOperator pods by labels
To prevent situations where the scheduler starts a
second k8sPodOperator pod after a restart, we now check
for existing pods using Kubernetes labels.
* Update airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
Co-authored-by: Kaxil Naik <ka...@gmail.com>
* Update airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
Co-authored-by: Kaxil Naik <ka...@gmail.com>
* add docs
* Update airflow/kubernetes/pod_launcher.py
Co-authored-by: Kaxil Naik <ka...@gmail.com>
Co-authored-by: Daniel Imberman <da...@astronomer.io>
Co-authored-by: Kaxil Naik <ka...@gmail.com>
(cherry picked from commit 8985df0bfcb5f2b2cd69a21b9814021f9f8ce953)
---
airflow/executors/kubernetes_executor.py | 58 ++--
airflow/kubernetes/pod_generator.py | 41 ++-
airflow/kubernetes/pod_launcher.py | 40 ++-
.../cncf/kubernetes/operators/kubernetes_pod.py | 301 ++++++++++++++-------
kubernetes_tests/test_kubernetes_pod_operator.py | 185 ++++++++++---
tests/executors/test_kubernetes_executor.py | 18 +-
6 files changed, 460 insertions(+), 183 deletions(-)
diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py
index 8b5fdc1..625e06b 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -16,37 +16,32 @@
# under the License.
"""Kubernetes executor"""
import base64
-import hashlib
-from queue import Empty
-
-import re
import json
import multiprocessing
-from uuid import uuid4
import time
-
-from dateutil import parser
+from queue import Empty
+from uuid import uuid4
import kubernetes
+from dateutil import parser
from kubernetes import watch, client
from kubernetes.client.rest import ApiException
from urllib3.exceptions import HTTPError, ReadTimeoutError
+from airflow import settings
from airflow.configuration import conf
-from airflow.kubernetes.pod_launcher import PodLauncher
+from airflow.exceptions import AirflowConfigException, AirflowException
+from airflow.executors.base_executor import BaseExecutor
+from airflow.kubernetes import pod_generator
from airflow.kubernetes.kube_client import get_kube_client
-from airflow.kubernetes.worker_configuration import WorkerConfiguration
+from airflow.kubernetes.pod_generator import MAX_POD_ID_LEN, MAX_LABEL_LEN
from airflow.kubernetes.pod_generator import PodGenerator
-from airflow.executors.base_executor import BaseExecutor
+from airflow.kubernetes.pod_launcher import PodLauncher
+from airflow.kubernetes.worker_configuration import WorkerConfiguration
from airflow.models import KubeResourceVersion, KubeWorkerIdentifier, TaskInstance
-from airflow.utils.state import State
from airflow.utils.db import provide_session, create_session
-from airflow import settings
-from airflow.exceptions import AirflowConfigException, AirflowException
from airflow.utils.log.logging_mixin import LoggingMixin
-
-MAX_POD_ID_LEN = 253
-MAX_LABEL_LEN = 63
+from airflow.utils.state import State
class KubeConfig:
@@ -402,8 +397,8 @@ class AirflowKubernetesScheduler(LoggingMixin):
namespace=self.namespace,
worker_uuid=self.worker_uuid,
pod_id=self._create_pod_id(dag_id, task_id),
- dag_id=self._make_safe_label_value(dag_id),
- task_id=self._make_safe_label_value(task_id),
+ dag_id=pod_generator.make_safe_label_value(dag_id),
+ task_id=pod_generator.make_safe_label_value(task_id),
try_number=try_number,
execution_date=self._datetime_to_label_safe_datestring(execution_date),
airflow_command=command
@@ -495,25 +490,6 @@ class AirflowKubernetesScheduler(LoggingMixin):
return safe_pod_id
@staticmethod
- def _make_safe_label_value(string):
- """
- Valid label values must be 63 characters or less and must be empty or begin and
- end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), underscores (_),
- dots (.), and alphanumerics between.
-
- If the label value is then greater than 63 chars once made safe, or differs in any
- way from the original value sent to this function, then we need to truncate to
- 53chars, and append it with a unique hash.
- """
- safe_label = re.sub(r'^[^a-z0-9A-Z]*|[^a-zA-Z0-9_\-\.]|[^a-z0-9A-Z]*$', '', string)
-
- if len(safe_label) > MAX_LABEL_LEN or string != safe_label:
- safe_hash = hashlib.md5(string.encode()).hexdigest()[:9]
- safe_label = safe_label[:MAX_LABEL_LEN - len(safe_hash) - 1] + "-" + safe_hash
-
- return safe_label
-
- @staticmethod
def _create_pod_id(dag_id, task_id):
safe_dag_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
dag_id)
@@ -599,8 +575,8 @@ class AirflowKubernetesScheduler(LoggingMixin):
)
for task in tasks:
if (
- self._make_safe_label_value(task.dag_id) == dag_id and
- self._make_safe_label_value(task.task_id) == task_id and
+ pod_generator.make_safe_label_value(task.dag_id) == dag_id and
+ pod_generator.make_safe_label_value(task.task_id) == task_id and
task.execution_date == ex_time
):
self.log.info(
@@ -683,8 +659,8 @@ class KubernetesExecutor(BaseExecutor, LoggingMixin):
# pylint: disable=protected-access
dict_string = (
"dag_id={},task_id={},execution_date={},airflow-worker={}".format(
- AirflowKubernetesScheduler._make_safe_label_value(task.dag_id),
- AirflowKubernetesScheduler._make_safe_label_value(task.task_id),
+ pod_generator.make_safe_label_value(task.dag_id),
+ pod_generator.make_safe_label_value(task.task_id),
AirflowKubernetesScheduler._datetime_to_label_safe_datestring(
task.execution_date
),
diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py
index 2a5a0df..2e6f2ba 100644
--- a/airflow/kubernetes/pod_generator.py
+++ b/airflow/kubernetes/pod_generator.py
@@ -22,10 +22,17 @@ is supported and no serialization need be written.
"""
import copy
+import hashlib
+import re
+import uuid
+
import kubernetes.client.models as k8s
+
from airflow.executors import Executors
-import uuid
+MAX_LABEL_LEN = 63
+
+MAX_POD_ID_LEN = 253
class PodDefaults:
"""
@@ -55,6 +62,25 @@ class PodDefaults:
)
+def make_safe_label_value(string):
+    """
+    Valid label values must be 63 characters or less and must be empty or begin and
+    end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), underscores (_),
+    dots (.), and alphanumerics between.
+
+    If the label value is greater than 63 chars once made safe, or differs in any
+    way from the original value sent to this function, then we need to truncate to
+    53 chars, and append it with a unique hash.
+
+    If nothing of the original value survives sanitising, only the hash is returned,
+    because a value starting with "-" would itself be an invalid label.
+
+    :param string: arbitrary value to turn into a Kubernetes label value
+    :return: a valid label value derived deterministically from ``string``
+    """
+    safe_label = re.sub(r"^[^a-z0-9A-Z]*|[^a-zA-Z0-9_\-\.]|[^a-z0-9A-Z]*$", "", string)
+
+    if len(safe_label) > MAX_LABEL_LEN or string != safe_label:
+        safe_hash = hashlib.md5(string.encode()).hexdigest()[:9]
+        prefix = safe_label[:MAX_LABEL_LEN - len(safe_hash) - 1]
+        # Labels must start with an alphanumeric character: never emit "-<hash>".
+        safe_label = prefix + "-" + safe_hash if prefix else safe_hash
+
+    return safe_label
+
+
class PodGenerator:
"""
Contains Kubernetes Airflow Worker configuration logic
@@ -201,9 +227,22 @@ class PodGenerator:
if self.extract_xcom:
result = self.add_sidecar(result)
+ result.metadata.name = self.make_unique_pod_id(result.metadata.name)
return result
@staticmethod
+    def make_unique_pod_id(dag_id):
+        """
+        Kubernetes pod names must be <= 253 chars and must pass the following regex for
+        validation
+        ``^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$``
+
+        A random uuid4 hex suffix is appended so repeated launches of the same task
+        get distinct pod names.
+
+        :param dag_id: a dag_id with only alphanumeric characters
+        :return: ``str`` valid Pod name of appropriate length, or ``None`` when
+            ``dag_id`` is empty
+        """
+        if not dag_id:
+            return None
+
+        safe_uuid = uuid.uuid4().hex
+        # Leave room for "-" plus the uuid, and drop separators left dangling by the
+        # cut so that no DNS segment ends in "." right before the suffix.
+        safe_pod_id = dag_id[:MAX_POD_ID_LEN - len(safe_uuid) - 1].rstrip(".-")
+        if not safe_pod_id:
+            return safe_uuid
+        return safe_pod_id + "-" + safe_uuid
+
+ @staticmethod
def add_sidecar(pod):
pod_cp = copy.deepcopy(pod)
diff --git a/airflow/kubernetes/pod_launcher.py b/airflow/kubernetes/pod_launcher.py
index 47d8ed5..6fb5989 100644
--- a/airflow/kubernetes/pod_launcher.py
+++ b/airflow/kubernetes/pod_launcher.py
@@ -89,13 +89,16 @@ class PodLauncher(LoggingMixin):
if e.status != 404:
raise
- def run_pod(self, pod, startup_timeout=120, get_logs=True):
+ def start_pod(
+ self,
+ pod,
+ startup_timeout):
"""
Launches the pod synchronously and waits for completion.
- Args:
- pod (Pod):
- startup_timeout (int): Timeout for startup of the pod (if pod is pending for
- too long, considers task a failure
+
+ :param pod:
+ :param startup_timeout: Timeout for startup of the pod (if pod is pending for too long, fails task)
+ :return:
"""
resp = self.run_pod_async(pod)
curr_time = dt.now()
@@ -107,9 +110,13 @@ class PodLauncher(LoggingMixin):
time.sleep(1)
self.log.debug('Pod not yet started')
- return self._monitor_pod(pod, get_logs)
-
- def _monitor_pod(self, pod, get_logs):
+ def monitor_pod(self, pod, get_logs):
+ """
+ :param pod: pod spec that will be monitored
+ :type pod : V1Pod
+ :param get_logs: whether to read the logs locally
+ :return: Tuple[State, Optional[str]]
+ """
if get_logs:
logs = self.read_pod_logs(pod)
@@ -180,6 +187,23 @@ class PodLauncher(LoggingMixin):
wait=tenacity.wait_exponential(),
reraise=True
)
+    def read_pod_events(self, pod):
+        """
+        List the Kubernetes events whose ``involvedObject`` is the given pod.
+
+        :param pod: pod whose events are requested
+        :return: the ``list_namespaced_event`` response for the pod's namespace
+        :raises AirflowException: if the Kubernetes API call fails
+        """
+        selector = "involvedObject.name={}".format(pod.metadata.name)
+        try:
+            events = self._client.list_namespaced_event(
+                namespace=pod.metadata.namespace,
+                field_selector=selector,
+            )
+        except BaseHTTPError as err:
+            raise AirflowException(
+                'There was an error reading the kubernetes API: {}'.format(err)
+            )
+        return events
+
+ @tenacity.retry(
+ stop=tenacity.stop_after_attempt(3),
+ wait=tenacity.wait_exponential(),
+ reraise=True
+ )
def read_pod(self, pod):
"""Read POD information"""
try:
diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
index b89a37f..49d9a9f 100644
--- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""Executes task in a Kubernetes POD"""
-import warnings
import re
@@ -80,6 +79,12 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance-
:param cluster_context: context that points to kubernetes cluster.
Ignored when in_cluster is True. If None, current-context is used.
:type cluster_context: str
+ :param reattach_on_restart: if the scheduler dies while the pod is running, reattach and monitor
+ :type reattach_on_restart: bool
+ :param labels: labels to apply to the Pod.
+ :type labels: dict
+ :param startup_timeout_seconds: timeout in seconds to startup the pod.
+ :type startup_timeout_seconds: int
:param get_logs: get the stdout of the container as logs of the tasks.
:type get_logs: bool
:param annotations: non-identifying metadata you can attach to the Pod.
@@ -126,90 +131,11 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance-
"""
template_fields = ('cmds', 'arguments', 'env_vars', 'config_file')
- def execute(self, context):
- try:
- client = kube_client.get_kube_client(in_cluster=self.in_cluster,
- cluster_context=self.cluster_context,
- config_file=self.config_file)
- # Add Airflow Version to the label
- # And a label to identify that pod is launched by KubernetesPodOperator
- self.labels.update(
- {
- 'airflow_version': airflow_version.replace('+', '-'),
- 'kubernetes_pod_operator': 'True',
- }
- )
-
- pod = pod_generator.PodGenerator(
- image=self.image,
- namespace=self.namespace,
- cmds=self.cmds,
- args=self.arguments,
- labels=self.labels,
- name=self.name,
- envs=self.env_vars,
- extract_xcom=self.do_xcom_push,
- image_pull_policy=self.image_pull_policy,
- node_selectors=self.node_selectors,
- priority_class_name=self.priority_class_name,
- annotations=self.annotations,
- affinity=self.affinity,
- init_containers=self.init_containers,
- image_pull_secrets=self.image_pull_secrets,
- service_account_name=self.service_account_name,
- hostnetwork=self.hostnetwork,
- tolerations=self.tolerations,
- configmaps=self.configmaps,
- security_context=self.security_context,
- dnspolicy=self.dnspolicy,
- pod=self.full_pod_spec,
- ).gen_pod()
-
- pod = append_to_pod(
- pod,
- self.pod_runtime_info_envs +
- self.ports +
- self.resources +
- self.secrets +
- self.volumes +
- self.volume_mounts
- )
-
- self.pod = pod
-
- launcher = pod_launcher.PodLauncher(kube_client=client,
- extract_xcom=self.do_xcom_push)
-
- try:
- (final_state, result) = launcher.run_pod(
- pod,
- startup_timeout=self.startup_timeout_seconds,
- get_logs=self.get_logs)
- finally:
- if self.is_delete_operator_pod:
- launcher.delete_pod(pod)
-
- if final_state != State.SUCCESS:
- raise AirflowException(
- 'Pod returned a failure: {state}'.format(state=final_state)
- )
-
- return result
- except AirflowException as ex:
- raise AirflowException('Pod Launching failed: {error}'.format(error=ex))
-
- def _set_resources(self, resources):
- return [Resources(**resources) if resources else Resources()]
-
- def _set_name(self, name):
- validate_key(name, max_length=220)
- return re.sub(r'[^a-z0-9.-]+', '-', name.lower())
-
@apply_defaults
def __init__(self, # pylint: disable=too-many-arguments,too-many-locals
- namespace,
- image,
- name,
+ namespace=None,
+ image=None,
+ name=None,
cmds=None,
arguments=None,
ports=None,
@@ -220,15 +146,14 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance-
in_cluster=None,
cluster_context=None,
labels=None,
+ reattach_on_restart=True,
startup_timeout_seconds=120,
get_logs=True,
image_pull_policy='IfNotPresent',
annotations=None,
resources=None,
affinity=None,
- init_containers=None,
config_file=None,
- do_xcom_push=False,
node_selectors=None,
image_pull_secrets=None,
service_account_name='default',
@@ -239,18 +164,19 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance-
security_context=None,
pod_runtime_info_envs=None,
dnspolicy=None,
+ schedulername=None,
full_pod_spec=None,
+ init_containers=None,
+ log_events_on_failure=False,
+ do_xcom_push=False,
+ pod_template_file=None,
priority_class_name=None,
*args,
**kwargs):
- # https://github.com/apache/airflow/blob/2d0eff4ee4fafcf8c7978ac287a8fb968e56605f/UPDATING.md#unification-of-do_xcom_push-flag
if kwargs.get('xcom_push') is not None:
- kwargs['do_xcom_push'] = kwargs.pop('xcom_push')
- warnings.warn(
- "`xcom_push` will be deprecated. Use `do_xcom_push` instead.",
- DeprecationWarning, stacklevel=2
- )
+ raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead")
super(KubernetesPodOperator, self).__init__(*args, resources=None, **kwargs)
+
self.pod = None
self.do_xcom_push = do_xcom_push
self.image = image
@@ -259,16 +185,14 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance-
self.arguments = arguments or []
self.labels = labels or {}
self.startup_timeout_seconds = startup_timeout_seconds
- self.name = self._set_name(name)
self.env_vars = env_vars or {}
self.ports = ports or []
- self.init_containers = init_containers or []
- self.priority_class_name = priority_class_name
self.volume_mounts = volume_mounts or []
self.volumes = volumes or []
self.secrets = secrets or []
self.in_cluster = in_cluster
self.cluster_context = cluster_context
+ self.reattach_on_restart = reattach_on_restart
self.get_logs = get_logs
self.image_pull_policy = image_pull_policy
self.node_selectors = node_selectors or {}
@@ -285,4 +209,193 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance-
self.security_context = security_context or {}
self.pod_runtime_info_envs = pod_runtime_info_envs or []
self.dnspolicy = dnspolicy
+ self.schedulername = schedulername
self.full_pod_spec = full_pod_spec
+ self.init_containers = init_containers or []
+ self.log_events_on_failure = log_events_on_failure
+ self.priority_class_name = priority_class_name
+ self.pod_template_file = pod_template_file
+ self.name = self._set_name(name)
+
+    @staticmethod
+    def create_labels_for_pod(context):
+        """
+        Build the labels that identify this task instance's pod, so the pod can be
+        found again if the operator crashes.
+
+        :param context: task context provided by airflow DAG
+        :return: dict mapping label names to Kubernetes-safe label values
+        """
+        dag = context['dag']
+        raw_labels = {
+            'dag_id': dag.dag_id,
+            'task_id': context['task'].task_id,
+            'execution_date': context['ts'],
+            'try_number': context['ti'].try_number,
+        }
+        # Sub-DAG pods additionally record the id of their parent DAG.
+        if dag.is_subdag:
+            raw_labels['parent_dag_id'] = dag.parent_dag.dag_id
+        # Stringify every value and make it a valid label (invalid characters removed,
+        # over-long or altered values truncated and suffixed with a short hash).
+        return {
+            name: pod_generator.make_safe_label_value(str(value))
+            for name, value in raw_labels.items()
+        }
+
+    def execute(self, context):
+        """
+        Run the task in a pod, reattaching to an already running pod of the same task
+        instance when one is found.
+
+        Pods are looked up with the labels from ``create_labels_for_pod`` (without
+        ``try_number``). With exactly one match, the existing pod is monitored when
+        ``reattach_on_restart`` is set, otherwise a new pod is created. With no
+        match a new pod is created.
+
+        :param context: task context provided by airflow DAG
+        :return: the ``result`` reported by the pod launcher for the pod
+        :raises AirflowException: if more than one matching pod exists or the pod
+            does not finish successfully
+        """
+        try:
+            if self.in_cluster is not None:
+                client = kube_client.get_kube_client(in_cluster=self.in_cluster,
+                                                     cluster_context=self.cluster_context,
+                                                     config_file=self.config_file)
+            else:
+                client = kube_client.get_kube_client(cluster_context=self.cluster_context,
+                                                     config_file=self.config_file)
+
+            # Add combination of labels to uniquely identify a running pod
+            labels = self.create_labels_for_pod(context)
+
+            label_selector = self._get_pod_identifying_label_string(labels)
+
+            pod_list = client.list_namespaced_pod(self.namespace, label_selector=label_selector)
+
+            if len(pod_list.items) > 1:
+                raise AirflowException(
+                    'More than one pod running with labels: '
+                    '{label_selector}'.format(label_selector=label_selector))
+
+            launcher = pod_launcher.PodLauncher(kube_client=client, extract_xcom=self.do_xcom_push)
+
+            if len(pod_list.items) == 1:
+                # V1PodList is not subscriptable: the pods live in ``.items``.
+                final_state, result = self._handle_pod_overlap(
+                    context, labels, launcher, pod_list.items[0])
+            else:
+                final_state, _, result = self.create_new_pod_for_operator(labels, launcher)
+            if final_state != State.SUCCESS:
+                raise AirflowException(
+                    'Pod returned a failure: {state}'.format(state=final_state))
+            return result
+        except AirflowException as ex:
+            raise AirflowException('Pod Launching failed: {error}'.format(error=ex))
+
+    def _handle_pod_overlap(self, context, labels, launcher, existing_pod):
+        """
+        Handle a pod left running by an earlier execution of this task instance.
+
+        With ``reattach_on_restart`` the existing pod is monitored to completion,
+        otherwise a new pod is created.
+
+        :return: ``(final_state, result)``
+        """
+        if self._try_numbers_do_not_match(context, existing_pod):
+            found = "found a running pod with labels %s but a different try_number."
+        else:
+            found = "found a running pod with labels %s and the same try_number."
+        if self.reattach_on_restart:
+            self.log.info(found + " Will attach to this pod and monitor instead of starting new one",
+                          labels)
+            return self.monitor_launched_pod(launcher, existing_pod)
+        self.log.info(found + " reattach_on_restart is disabled, starting a new pod", labels)
+        final_state, _, result = self.create_new_pod_for_operator(labels, launcher)
+        return final_state, result
+
+    @staticmethod
+    def _get_pod_identifying_label_string(labels):
+        """
+        Render ``labels`` as a Kubernetes label selector ``k1=v1,k2=v2`` sorted by key.
+
+        ``try_number`` is left out, so pods launched by any try of the task match.
+        """
+        pairs = sorted(
+            (label_id, label) for label_id, label in labels.items() if label_id != 'try_number'
+        )
+        return ','.join(label_id + '=' + label for label_id, label in pairs)
+
+    @staticmethod
+    def _try_numbers_do_not_match(context, pod):
+        """
+        Tell whether ``pod`` was launched for a different try than the current one.
+
+        Label values are stored as strings (``create_labels_for_pod`` stringifies
+        them), while the task instance's ``try_number`` is presumably an int, so both
+        sides are compared as strings. Comparing the raw values would always mismatch.
+
+        :param context: task context provided by airflow DAG
+        :param pod: pod returned by the Kubernetes API
+        :return: ``True`` if the pod's ``try_number`` label differs or is missing
+        """
+        pod_try_number = (pod.metadata.labels or {}).get('try_number')
+        return pod_try_number != str(context['ti'].try_number)
+
+    @staticmethod
+    def _set_resources(resources):
+        """Wrap an optional resources dict in a one-element list; no resources gives ``[]``."""
+        return [Resources(**resources)] if resources else []
+
+    def _set_name(self, name):
+        """
+        Validate ``name`` and normalise it into a pod name.
+
+        When the pod comes from ``pod_template_file`` or ``full_pod_spec`` the operator
+        does not name the pod itself and ``None`` is returned.
+
+        :param name: user supplied pod name
+        :return: lower-cased name with each run of characters outside ``[a-z0-9.-]``
+            replaced by ``-``, or ``None``
+        :raises AirflowException: if ``name`` is missing and no pod template/spec is given
+        """
+        if self.pod_template_file or self.full_pod_spec:
+            return None
+        # ``name`` now defaults to None; fail with a clear message instead of letting
+        # validate_key reject a non-string.
+        if name is None:
+            raise AirflowException(
+                "`name` is required unless `pod_template_file` or `full_pod_spec` is set")
+        validate_key(name, max_length=220)
+        return re.sub(r'[^a-z0-9.-]+', '-', name.lower())
+
+    def create_new_pod_for_operator(self, labels, launcher):
+        """
+        Creates a new pod and monitors for duration of task
+
+        :param labels: labels used to track pod. They are applied even when the pod is
+            built from ``full_pod_spec`` or ``pod_template_file``, because ``execute``
+            looks the pod up by these labels to reattach after a restart.
+        :param launcher: pod launcher that will manage launching and monitoring pods
+        :return: ``(final_state, pod, result)``
+        """
+        if not (self.full_pod_spec or self.pod_template_file):
+            # Add Airflow Version to the label
+            # And a label to identify that pod is launched by KubernetesPodOperator
+            self.labels.update(
+                {
+                    'airflow_version': airflow_version.replace('+', '-'),
+                    'kubernetes_pod_operator': 'True',
+                }
+            )
+        # Tracking labels must be on every pod, otherwise the label-selector lookup
+        # in execute() can never find it again.
+        self.labels.update(labels)
+        pod = pod_generator.PodGenerator(
+            image=self.image,
+            namespace=self.namespace,
+            cmds=self.cmds,
+            args=self.arguments,
+            labels=self.labels,
+            name=self.name,
+            envs=self.env_vars,
+            extract_xcom=self.do_xcom_push,
+            image_pull_policy=self.image_pull_policy,
+            node_selectors=self.node_selectors,
+            annotations=self.annotations,
+            affinity=self.affinity,
+            image_pull_secrets=self.image_pull_secrets,
+            service_account_name=self.service_account_name,
+            hostnetwork=self.hostnetwork,
+            tolerations=self.tolerations,
+            configmaps=self.configmaps,
+            security_context=self.security_context,
+            dnspolicy=self.dnspolicy,
+            schedulername=self.schedulername,
+            init_containers=self.init_containers,
+            restart_policy='Never',
+            priority_class_name=self.priority_class_name,
+            pod=self.full_pod_spec,
+        ).gen_pod()
+
+        # noinspection PyTypeChecker
+        pod = append_to_pod(
+            pod,
+            self.pod_runtime_info_envs +  # type: ignore
+            self.ports +  # type: ignore
+            self.resources +  # type: ignore
+            self.secrets +  # type: ignore
+            self.volumes +  # type: ignore
+            self.volume_mounts  # type: ignore
+        )
+
+        self.pod = pod
+
+        try:
+            launcher.start_pod(
+                pod,
+                startup_timeout=self.startup_timeout_seconds)
+            final_state, result = launcher.monitor_pod(pod=pod, get_logs=self.get_logs)
+        except AirflowException:
+            # Dump the pod's events for diagnosis before the failure propagates.
+            if self.log_events_on_failure:
+                for event in launcher.read_pod_events(pod).items:
+                    self.log.error("Pod Event: %s - %s", event.reason, event.message)
+            raise
+        finally:
+            # Clean up whether the pod succeeded, failed, or never started.
+            if self.is_delete_operator_pod:
+                launcher.delete_pod(pod)
+        return final_state, pod, result
+
+    def monitor_launched_pod(self, launcher, pod):
+        """
+        Monitors a pod to completion that was created by a previous KubernetesPodOperator
+
+        On failure the pod's events are logged (when ``log_events_on_failure`` is set)
+        before the pod is deleted, so the lookup does not race with the deletion.
+
+        :param launcher: pod launcher that will manage launching and monitoring pods
+        :param pod: podspec used to find pod using k8s API
+        :return: ``(final_state, result)``
+        :raises AirflowException: if the pod does not finish successfully
+        """
+        try:
+            final_state, result = launcher.monitor_pod(pod, get_logs=self.get_logs)
+            if final_state != State.SUCCESS:
+                if self.log_events_on_failure:
+                    for event in launcher.read_pod_events(pod).items:
+                        self.log.error("Pod Event: %s - %s", event.reason, event.message)
+                raise AirflowException(
+                    'Pod returned a failure: {state}'.format(state=final_state)
+                )
+        finally:
+            if self.is_delete_operator_pod:
+                launcher.delete_pod(pod)
+        return final_state, result
diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py
index 814318f..ddd6dce 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator.py
@@ -21,14 +21,12 @@ import os
import shutil
import unittest
-from airflow.contrib.operators.kubernetes_pod_operator import KubernetesPodOperator
-from tests.compat import mock, patch
-
import kubernetes.client.models as k8s
import pendulum
from kubernetes.client.api_client import ApiClient
from kubernetes.client.rest import ApiException
+from airflow.contrib.operators.kubernetes_pod_operator import KubernetesPodOperator
from airflow.exceptions import AirflowException
from airflow.kubernetes import kube_client
from airflow.kubernetes.pod import Port
@@ -38,9 +36,9 @@ from airflow.kubernetes.secret import Secret
from airflow.kubernetes.volume import Volume
from airflow.kubernetes.volume_mount import VolumeMount
from airflow.models import DAG, TaskInstance
-
from airflow.utils import timezone
from airflow.version import version as airflow_version
+from tests.compat import mock, patch
# noinspection DuplicatedCode
@@ -74,11 +72,10 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
'labels': {
'foo': 'bar', 'kubernetes_pod_operator': 'True',
'airflow_version': airflow_version.replace('+', '-'),
- # 'execution_date': '2016-01-01T0100000100-a2f50a31f',
- # 'dag_id': 'dag',
- # 'task_id': 'task',
- # 'try_number': '1'
- },
+ 'execution_date': '2016-01-01T0100000100-a2f50a31f',
+ 'dag_id': 'dag',
+ 'task_id': 'task',
+ 'try_number': '1'},
},
'spec': {
'affinity': {},
@@ -113,6 +110,19 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
client = kube_client.get_kube_client(in_cluster=False)
client.delete_collection_namespaced_pod(namespace="default")
+    def create_context(self, task):
+        """Build the task context entries (dag, ts, task, ti) that the operator reads."""
+        amsterdam = pendulum.timezone("Europe/Amsterdam")
+        run_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=amsterdam)
+        return {
+            "dag": DAG(dag_id="dag"),
+            "ts": run_date.isoformat(),
+            "task": task,
+            "ti": TaskInstance(task=task, execution_date=run_date),
+        }
+
def test_do_xcom_push_defaults_false(self):
new_config_path = '/tmp/kube_config'
old_config_path = os.path.expanduser('~/.kube/config')
@@ -149,11 +159,98 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
do_xcom_push=False,
config_file=new_config_path,
)
- context = create_context(k)
+ context = self.create_context(k)
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.assertEqual(self.expected_pod, actual_pod)
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
+    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
+    def test_config_path(self, client_mock, monitor_mock, start_mock):  # pylint: disable=unused-argument
+        """The kube client must be built from the configured context and config file."""
+        from airflow.utils.state import State
+
+        file_path = "/tmp/fake_file"
+        k = KubernetesPodOperator(
+            namespace='default',
+            image="ubuntu:16.04",
+            cmds=["bash", "-cx"],
+            arguments=["echo 10"],
+            labels={"foo": "bar"},
+            name="test",
+            task_id="task",
+            in_cluster=False,
+            do_xcom_push=False,
+            config_file=file_path,
+            cluster_context='default',
+        )
+        monitor_mock.return_value = (State.SUCCESS, None)
+        # execute() lists pods on the client *returned* by get_kube_client;
+        # report that no pod for this task is running yet.
+        client_mock.return_value.list_namespaced_pod.return_value.items = []
+        context = self.create_context(k)
+        k.execute(context=context)
+        client_mock.assert_called_once_with(
+            in_cluster=False,
+            cluster_context='default',
+            config_file=file_path,
+        )
+
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
+    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
+    def test_image_pull_secrets_correctly_set(self, mock_client, monitor_mock, start_mock):
+        """image_pull_secrets must reach the spec of the pod handed to start_pod."""
+        from airflow.utils.state import State
+
+        fake_pull_secrets = "fakeSecret"
+        k = KubernetesPodOperator(
+            namespace='default',
+            image="ubuntu:16.04",
+            cmds=["bash", "-cx"],
+            arguments=["echo 10"],
+            labels={"foo": "bar"},
+            name="test",
+            task_id="task",
+            in_cluster=False,
+            do_xcom_push=False,
+            image_pull_secrets=fake_pull_secrets,
+            cluster_context='default',
+        )
+        monitor_mock.return_value = (State.SUCCESS, None)
+        # No existing pod, so execute() takes the create-new-pod path.
+        mock_client.return_value.list_namespaced_pod.return_value.items = []
+        context = self.create_context(k)
+        k.execute(context=context)
+        self.assertEqual(
+            start_mock.call_args[0][0].spec.image_pull_secrets,
+            [k8s.V1LocalObjectReference(name=fake_pull_secrets)]
+        )
+
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.delete_pod")
+    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
+    def test_pod_delete_even_on_launcher_error(
+            self,
+            mock_client,
+            delete_pod_mock,
+            monitor_pod_mock,
+            start_pod_mock):  # pylint: disable=unused-argument
+        """With is_delete_operator_pod the pod is deleted even if monitoring raises."""
+        k = KubernetesPodOperator(
+            namespace='default',
+            image="ubuntu:16.04",
+            cmds=["bash", "-cx"],
+            arguments=["echo 10"],
+            labels={"foo": "bar"},
+            name="test",
+            task_id="task",
+            in_cluster=False,
+            do_xcom_push=False,
+            cluster_context='default',
+            is_delete_operator_pod=True,
+        )
+        monitor_pod_mock.side_effect = AirflowException('fake failure')
+        # No existing pod, so execute() creates (and must then delete) a new one.
+        mock_client.return_value.list_namespaced_pod.return_value.items = []
+        context = self.create_context(k)
+        with self.assertRaises(AirflowException):
+            k.execute(context=context)
+        self.assertTrue(delete_pod_mock.called)
+
def test_working_pod(self):
k = KubernetesPodOperator(
namespace='default',
@@ -185,7 +282,7 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
do_xcom_push=False,
is_delete_operator_pod=True,
)
- context = create_context(k)
+ context = self.create_context(k)
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
@@ -204,7 +301,7 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
do_xcom_push=False,
hostnetwork=True,
)
- context = create_context(k)
+ context = self.create_context(k)
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['hostNetwork'] = True
@@ -226,7 +323,7 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
hostnetwork=True,
dnspolicy=dns_policy
)
- context = create_context(k)
+ context = self.create_context(k)
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['hostNetwork'] = True
@@ -234,6 +331,28 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
self.assertEqual(self.expected_pod['metadata']['labels'], actual_pod['metadata']['labels'])
+    def test_pod_schedulername(self):
+        """The ``schedulername`` argument must appear as ``spec.schedulerName`` on the pod."""
+        scheduler_name = "default-scheduler"
+        k = KubernetesPodOperator(
+            namespace="default",
+            image="ubuntu:16.04",
+            cmds=["bash", "-cx"],
+            arguments=["echo 10"],
+            labels={"foo": "bar"},
+            name="test",
+            task_id="task",
+            in_cluster=False,
+            do_xcom_push=False,
+            schedulername=scheduler_name
+        )
+        context = self.create_context(k)
+        k.execute(context)
+        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
+        self.expected_pod['spec']['schedulerName'] = scheduler_name
+        # The whole-pod comparison already covers the two narrower checks below.
+        self.assertEqual(self.expected_pod, actual_pod)
+        self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
+        self.assertEqual(self.expected_pod['metadata']['labels'], actual_pod['metadata']['labels'])
+
def test_pod_node_selectors(self):
node_selectors = {
'beta.kubernetes.io/os': 'linux'
@@ -275,7 +394,7 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
do_xcom_push=False,
resources=resources,
)
- context = create_context(k)
+ context = self.create_context(k)
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['containers'][0]['resources'] = {
@@ -342,7 +461,7 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
do_xcom_push=False,
ports=[port],
)
- context = create_context(k)
+ context = self.create_context(k)
k.execute(context=context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['containers'][0]['ports'] = [{
@@ -564,9 +683,10 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
self.expected_pod['spec']['containers'].append(container)
self.assertEqual(self.expected_pod, actual_pod)
- @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod")
- @patch("airflow.kubernetes.kube_client.get_kube_client")
- def test_envs_from_configmaps(self, mock_client, mock_run):
+ @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
+ @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
+ @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
+ def test_envs_from_configmaps(self, mock_client, mock_monitor, mock_start):
# GIVEN
from airflow.utils.state import State
@@ -585,19 +705,20 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
configmaps=[configmap],
)
# THEN
- mock_run.return_value = (State.SUCCESS, None)
- context = create_context(k)
+ mock_monitor.return_value = (State.SUCCESS, None)
+ context = self.create_context(k)
k.execute(context)
self.assertEqual(
- mock_run.call_args[0][0].spec.containers[0].env_from,
+ mock_start.call_args[0][0].spec.containers[0].env_from,
[k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(
name=configmap
))]
)
- @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod")
- @patch("airflow.kubernetes.kube_client.get_kube_client")
- def test_envs_from_secrets(self, mock_client, mock_run):
+ @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
+ @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
+ @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
+ def test_envs_from_secrets(self, mock_client, monitor_mock, start_mock):
# GIVEN
from airflow.utils.state import State
secret_ref = 'secret_name'
@@ -616,11 +737,11 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
do_xcom_push=False,
)
# THEN
- mock_run.return_value = (State.SUCCESS, None)
- context = create_context(k)
+ monitor_mock.return_value = (State.SUCCESS, None)
+ context = self.create_context(k)
k.execute(context)
self.assertEqual(
- mock_run.call_args[0][0].spec.containers[0].env_from,
+ start_mock.call_args[0][0].spec.containers[0].env_from,
[k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(
name=secret_ref
))]
@@ -704,12 +825,14 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
}]
self.assertEqual(self.expected_pod, actual_pod)
- @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod")
+ @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
+ @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@patch("airflow.kubernetes.kube_client.get_kube_client")
def test_pod_priority_class_name(
self,
mock_client,
- run_mock): # pylint: disable=unused-argument
+ monitor_mock,
+ start_mock): # pylint: disable=unused-argument
"""Test ability to assign priorityClassName to pod
"""
@@ -729,8 +852,8 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
priority_class_name=priority_class_name,
)
- run_mock.return_value = (State.SUCCESS, None)
- context = create_context(k)
+ monitor_mock.return_value = (State.SUCCESS, None)
+ context = self.create_context(k)
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['priorityClassName'] = priority_class_name
diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py
index 77299f6..cf7ba54 100644
--- a/tests/executors/test_kubernetes_executor.py
+++ b/tests/executors/test_kubernetes_executor.py
@@ -33,6 +33,8 @@ try:
from airflow.configuration import conf # noqa: F401
from airflow.executors.kubernetes_executor import AirflowKubernetesScheduler
from airflow.executors.kubernetes_executor import KubernetesExecutor
+ from airflow.kubernetes import pod_generator
+ from airflow.kubernetes.pod_generator import PodGenerator
from airflow.utils.state import State
except ImportError:
AirflowKubernetesScheduler = None # type: ignore
@@ -87,24 +89,24 @@ class TestAirflowKubernetesScheduler(unittest.TestCase):
'kubernetes python package is not installed')
def test_create_pod_id(self):
for dag_id, task_id in self._cases():
- pod_name = AirflowKubernetesScheduler._create_pod_id(dag_id, task_id)
+ pod_name = PodGenerator.make_unique_pod_id(
+ AirflowKubernetesScheduler._create_pod_id(dag_id, task_id)
+ )
self.assertTrue(self._is_valid_pod_id(pod_name))
def test_make_safe_label_value(self):
for dag_id, task_id in self._cases():
-            safe_dag_id = AirflowKubernetesScheduler._make_safe_label_value(dag_id)
+            safe_dag_id = pod_generator.make_safe_label_value(dag_id)
self.assertTrue(self._is_safe_label_value(safe_dag_id))
-            safe_task_id = AirflowKubernetesScheduler._make_safe_label_value(task_id)
+            safe_task_id = pod_generator.make_safe_label_value(task_id)
self.assertTrue(self._is_safe_label_value(safe_task_id))
-            id = "my_dag_id"
+            # An already-safe value must be returned unchanged.
+            dag_id = "my_dag_id"
self.assertEqual(
-                id,
-                AirflowKubernetesScheduler._make_safe_label_value(id)
+                dag_id,
+                pod_generator.make_safe_label_value(dag_id)
)
-            id = "my_dag_id_" + "a" * 64
+            # An over-long value is truncated and suffixed with a 9-char hash.
+            dag_id = "my_dag_id_" + "a" * 64
self.assertEqual(
"my_dag_id_" + "a" * 43 + "-0ce114c45",
-                AirflowKubernetesScheduler._make_safe_label_value(id)
+                pod_generator.make_safe_label_value(dag_id)
)
@unittest.skipIf(AirflowKubernetesScheduler is None,