You are viewing a plain text version of this content. The canonical link to the original archived message was not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by po...@apache.org on 2020/06/29 13:21:26 UTC
[airflow] 20/37: [AIRFLOW-5413] Refactor worker config (#7114)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 0a853380fcc5dee7533c70bb0adfde0ebca7b420
Author: davlum <da...@gmail.com>
AuthorDate: Thu Jan 9 15:39:05 2020 -0500
[AIRFLOW-5413] Refactor worker config (#7114)
(cherry picked from commit 51f262c65afd7eaecc54661a3b5c4e533feecff8)
---
.github/workflows/ci.yml | 2 +-
.../contrib/operators/kubernetes_pod_operator.py | 4 +-
airflow/executors/kubernetes_executor.py | 13 +-
airflow/kubernetes/pod_generator.py | 259 ++++++++--
airflow/kubernetes/worker_configuration.py | 18 +-
tests/executors/test_kubernetes_executor.py | 100 ++--
tests/kubernetes/test_pod_generator.py | 541 ++++++++++++++++++---
tests/kubernetes/test_worker_configuration.py | 95 +++-
8 files changed, 838 insertions(+), 194 deletions(-)
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index fb16aaf..67b9e50 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -144,7 +144,7 @@ jobs:
- name: Cache virtualenv for kubernetes testing
uses: actions/cache@v2
env:
- cache-name: cache-kubernetes-tests-virtualenv-v2
+ cache-name: cache-kubernetes-tests-virtualenv-v3
with:
path: .build/.kubernetes_venv
key: "${{ env.cache-name }}-${{ github.job }}-\
diff --git a/airflow/contrib/operators/kubernetes_pod_operator.py b/airflow/contrib/operators/kubernetes_pod_operator.py
index 8adb131..d439eda 100644
--- a/airflow/contrib/operators/kubernetes_pod_operator.py
+++ b/airflow/contrib/operators/kubernetes_pod_operator.py
@@ -367,11 +367,11 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance-
pod,
startup_timeout=self.startup_timeout_seconds)
final_state, result = launcher.monitor_pod(pod=pod, get_logs=self.get_logs)
- except AirflowException:
+ except AirflowException as ex:
if self.log_events_on_failure:
for event in launcher.read_pod_events(pod).items:
self.log.error("Pod Event: %s - %s", event.reason, event.message)
- raise
+ raise AirflowException('Pod Launching failed: {error}'.format(error=ex))
finally:
if self.is_delete_operator_pod:
launcher.delete_pod(pod)
diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py
index d458d7a..74e504e 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -71,7 +71,7 @@ class KubeConfig:
self.kubernetes_section, "worker_container_image_pull_policy"
)
self.kube_node_selectors = configuration_dict.get('kubernetes_node_selectors', {})
- self.kube_annotations = configuration_dict.get('kubernetes_annotations', {})
+ self.kube_annotations = configuration_dict.get('kubernetes_annotations', {}) or None
self.kube_labels = configuration_dict.get('kubernetes_labels', {})
self.delete_worker_pods = conf.getboolean(
self.kubernetes_section, 'delete_worker_pods')
@@ -357,7 +357,7 @@ class AirflowKubernetesScheduler(LoggingMixin):
self.log.debug("Kubernetes using namespace %s", self.namespace)
self.kube_client = kube_client
self.launcher = PodLauncher(kube_client=self.kube_client)
- self.worker_configuration = WorkerConfiguration(kube_config=self.kube_config)
+ self.worker_configuration_pod = WorkerConfiguration(kube_config=self.kube_config).as_pod()
self._manager = multiprocessing.Manager()
self.watcher_queue = self._manager.Queue()
self.worker_uuid = worker_uuid
@@ -393,19 +393,20 @@ class AirflowKubernetesScheduler(LoggingMixin):
if command[0:2] != ["airflow", "run"]:
raise ValueError('The command must start with ["airflow", "run"].')
- config_pod = self.worker_configuration.make_pod(
+ pod = PodGenerator.construct_pod(
namespace=self.namespace,
worker_uuid=self.worker_uuid,
pod_id=self._create_pod_id(dag_id, task_id),
dag_id=pod_generator.make_safe_label_value(dag_id),
task_id=pod_generator.make_safe_label_value(task_id),
try_number=try_number,
- execution_date=self._datetime_to_label_safe_datestring(execution_date),
- airflow_command=command
+ date=self._datetime_to_label_safe_datestring(execution_date),
+ command=command,
+ kube_executor_config=kube_executor_config,
+ worker_config=self.worker_configuration_pod
)
# Reconcile the pod generated by the Operator and the Pod
# generated by the .cfg file
- pod = PodGenerator.reconcile_pods(config_pod, kube_executor_config)
self.log.debug("Kubernetes running for command %s", command)
self.log.debug("Kubernetes launching image %s", pod.spec.containers[0].image)
diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py
index a614f41..bf0cedf 100644
--- a/airflow/kubernetes/pod_generator.py
+++ b/airflow/kubernetes/pod_generator.py
@@ -28,7 +28,7 @@ import uuid
import kubernetes.client.models as k8s
-from airflow.executors import Executors
+from airflow.version import version as airflow_version
MAX_LABEL_LEN = 63
@@ -87,28 +87,59 @@ class PodGenerator:
Contains Kubernetes Airflow Worker configuration logic
Represents a kubernetes pod and manages execution of a single pod.
+ Any configuration that is container specific gets applied to
+ the first container in the list of containers.
+
+ Parameters with a type of `kubernetes.client.models.*`/`k8s.*` can
+ often be replaced with their dictionary equivalent, for example the output of
+ `sanitize_for_serialization`.
+
:param image: The docker image
- :type image: str
+ :type image: Optional[str]
+ :param name: name in the metadata section (not the container name)
+ :type name: Optional[str]
+ :param namespace: pod namespace
+ :type namespace: Optional[str]
+ :param volume_mounts: list of kubernetes volume mounts
+ :type volume_mounts: Optional[List[Union[k8s.V1VolumeMount, dict]]]
:param envs: A dict containing the environment variables
- :type envs: Dict[str, str]
- :param cmds: The command to be run on the pod
- :type cmds: List[str]
- :param secrets: Secrets to be launched to the pod
- :type secrets: List[airflow.kubernetes.models.secret.Secret]
+ :type envs: Optional[Dict[str, str]]
+ :param cmds: The command to be run on the first container
+ :type cmds: Optional[List[str]]
+ :param args: The arguments passed to the command of the first container
+ :type args: Optional[List[str]]
+ :param labels: labels for the pod metadata
+ :type labels: Optional[Dict[str, str]]
+ :param node_selectors: node selectors for the pod
+ :type node_selectors: Optional[Dict[str, str]]
+ :param ports: list of ports. Applies to the first container.
+ :type ports: Optional[List[Union[k8s.V1ContainerPort, dict]]]
+ :param volumes: Volumes to be attached to the pod spec (mount them into the first container via volume_mounts)
+ :type volumes: Optional[List[Union[k8s.V1Volume, dict]]]
:param image_pull_policy: Specify a policy to cache or always pull an image
:type image_pull_policy: str
+ :param restart_policy: The restart policy of the pod
+ :type restart_policy: str
:param image_pull_secrets: Any image pull secrets to be given to the pod.
If more than one secret is required, provide a comma separated list:
secret_a,secret_b
:type image_pull_secrets: str
+ :param init_containers: A list of init containers
+ :type init_containers: Optional[List[k8s.V1Container]]
+ :param service_account_name: Identity for processes that run in a Pod
+ :type service_account_name: Optional[str]
+ :param resources: Resource requirements for the first container
+ :type resources: Optional[Union[k8s.V1ResourceRequirements, dict]]
+ :param annotations: annotations for the pod
+ :type annotations: Optional[Dict[str, str]]
:param affinity: A dict containing a group of affinity scheduling rules
- :type affinity: dict
+ :type affinity: Optional[dict]
:param hostnetwork: If True enable host networking on the pod
:type hostnetwork: bool
:param tolerations: A list of kubernetes tolerations
- :type tolerations: list
+ :type tolerations: Optional[list]
:param security_context: A dict containing the security context for the pod
- :type security_context: dict
+ :type security_context: Optional[Union[k8s.V1PodSecurityContext, dict]]
:param configmaps: Any configmap refs to envfrom.
If more than one configmap is required, provide a comma separated list
configmap_a,configmap_b
@@ -117,11 +148,13 @@ class PodGenerator:
:type dnspolicy: str
:param pod: The fully specified pod.
:type pod: kubernetes.client.models.V1Pod
+ :param extract_xcom: Whether to bring up a container for xcom
+ :type extract_xcom: bool
"""
def __init__(
self,
- image,
+ image=None,
name=None,
namespace=None,
volume_mounts=None,
@@ -225,10 +258,11 @@ class PodGenerator:
result.metadata = self.metadata
result.spec.containers = [self.container]
+ result.metadata.name = self.make_unique_pod_id(result.metadata.name)
+
if self.extract_xcom:
result = self.add_sidecar(result)
- result.metadata.name = self.make_unique_pod_id(result.metadata.name)
return result
@staticmethod
@@ -252,8 +286,9 @@ class PodGenerator:
@staticmethod
def add_sidecar(pod):
pod_cp = copy.deepcopy(pod)
-
+ pod_cp.spec.volumes = pod.spec.volumes or []
pod_cp.spec.volumes.insert(0, PodDefaults.VOLUME)
+ pod_cp.spec.containers[0].volume_mounts = pod_cp.spec.containers[0].volume_mounts or []
pod_cp.spec.containers[0].volume_mounts.insert(0, PodDefaults.VOLUME_MOUNT)
pod_cp.spec.containers.append(PodDefaults.SIDECAR_CONTAINER)
@@ -262,7 +297,7 @@ class PodGenerator:
@staticmethod
def from_obj(obj):
if obj is None:
- return k8s.V1Pod()
+ return None
if isinstance(obj, PodGenerator):
return obj.gen_pod()
@@ -272,7 +307,12 @@ class PodGenerator:
'Cannot convert a non-dictionary or non-PodGenerator '
'object into a KubernetesExecutorConfig')
- namespaced = obj.get(Executors.KubernetesExecutor, {})
+ # We do not want to import the constant from ExecutorLoader here: it is just
+ # a key in a dictionary rather than an executor-selection mechanism, and importing it causes a cyclic import
+ namespaced = obj.get("KubernetesExecutor", {})
+
+ if not namespaced:
+ return None
resources = namespaced.get('resources')
@@ -348,46 +388,159 @@ class PodGenerator:
should be preserved from base, the volumes appended to and
the other fields overwritten.
"""
+ if client_pod is None:
+ return base_pod
client_pod_cp = copy.deepcopy(client_pod)
+ client_pod_cp.spec = PodGenerator.reconcile_specs(base_pod.spec, client_pod_cp.spec)
- def merge_objects(base_obj, client_obj):
- for base_key in base_obj.to_dict().keys():
- base_val = getattr(base_obj, base_key, None)
- if not getattr(client_obj, base_key, None) and base_val:
- setattr(client_obj, base_key, base_val)
-
- def extend_object_field(base_obj, client_obj, field_name):
- base_obj_field = getattr(base_obj, field_name, None)
- client_obj_field = getattr(client_obj, field_name, None)
- if not base_obj_field:
- return
- if not client_obj_field:
- setattr(client_obj, field_name, base_obj_field)
- return
- appended_fields = base_obj_field + client_obj_field
- setattr(client_obj, field_name, appended_fields)
-
- # Values at the pod and metadata should be overwritten where they exist,
- # but certain values at the spec and container level must be conserved.
- base_container = base_pod.spec.containers[0]
- client_container = client_pod_cp.spec.containers[0]
-
- extend_object_field(base_container, client_container, 'volume_mounts')
- extend_object_field(base_container, client_container, 'env')
- extend_object_field(base_container, client_container, 'env_from')
- extend_object_field(base_container, client_container, 'ports')
- extend_object_field(base_container, client_container, 'volume_devices')
- client_container.command = base_container.command
- client_container.args = base_container.args
- merge_objects(base_pod.spec.containers[0], client_pod_cp.spec.containers[0])
- # Just append any additional containers from the base pod
- client_pod_cp.spec.containers.extend(base_pod.spec.containers[1:])
-
- merge_objects(base_pod.metadata, client_pod_cp.metadata)
-
- extend_object_field(base_pod.spec, client_pod_cp.spec, 'volumes')
- merge_objects(base_pod.spec, client_pod_cp.spec)
- merge_objects(base_pod, client_pod_cp)
+ client_pod_cp.metadata = merge_objects(base_pod.metadata, client_pod_cp.metadata)
+ client_pod_cp = merge_objects(base_pod, client_pod_cp)
return client_pod_cp
+
+ @staticmethod
+ def reconcile_specs(base_spec,
+ client_spec):
+ """
+ :param base_spec: has the base attributes which are overwritten if they exist
+ in the client_spec and remain if they do not exist in the client_spec
+ :type base_spec: k8s.V1PodSpec
+ :param client_spec: the spec that the client wants to create.
+ :type client_spec: k8s.V1PodSpec
+ :return: the merged specs
+ """
+ if base_spec and not client_spec:
+ return base_spec
+ if not base_spec and client_spec:
+ return client_spec
+ elif client_spec and base_spec:
+ client_spec.containers = PodGenerator.reconcile_containers(
+ base_spec.containers, client_spec.containers
+ )
+ merged_spec = extend_object_field(base_spec, client_spec, 'volumes')
+ return merge_objects(base_spec, merged_spec)
+
+ return None
+
+ @staticmethod
+ def reconcile_containers(base_containers,
+ client_containers):
+ """
+ :param base_containers: has the base attributes which are overwritten if they exist
+ in the client_containers and remain if they do not exist in the client_containers
+ :type base_containers: List[k8s.V1Container]
+ :param client_containers: the containers that the client wants to create.
+ :type client_containers: List[k8s.V1Container]
+ :return: the merged containers
+
+ This runs recursively over the list of containers.
+ """
+ if not base_containers:
+ return client_containers
+ if not client_containers:
+ return base_containers
+
+ client_container = client_containers[0]
+ base_container = base_containers[0]
+ client_container = extend_object_field(base_container, client_container, 'volume_mounts')
+ client_container = extend_object_field(base_container, client_container, 'env')
+ client_container = extend_object_field(base_container, client_container, 'env_from')
+ client_container = extend_object_field(base_container, client_container, 'ports')
+ client_container = extend_object_field(base_container, client_container, 'volume_devices')
+ client_container = merge_objects(base_container, client_container)
+
+ return [client_container] + PodGenerator.reconcile_containers(
+ base_containers[1:], client_containers[1:]
+ )
+
+ @staticmethod
+ def construct_pod(
+ dag_id,
+ task_id,
+ pod_id,
+ try_number,
+ date,
+ command,
+ kube_executor_config,
+ worker_config,
+ namespace,
+ worker_uuid
+ ):
+ """
+ Construct a pod by gathering and consolidating the configuration from 3 places:
+ - airflow.cfg
+ - executor_config
+ - dynamic arguments
+ """
+ dynamic_pod = PodGenerator(
+ namespace=namespace,
+ image='',
+ labels={
+ 'airflow-worker': worker_uuid,
+ 'dag_id': dag_id,
+ 'task_id': task_id,
+ 'execution_date': date,
+ 'try_number': str(try_number),
+ 'airflow_version': airflow_version.replace('+', '-'),
+ 'kubernetes_executor': 'True',
+ },
+ cmds=command,
+ name=pod_id
+ ).gen_pod()
+
+ # Reconcile the pod generated by the Operator and the Pod
+ # generated by the .cfg file
+ pod_with_executor_config = PodGenerator.reconcile_pods(worker_config,
+ kube_executor_config)
+ # Reconcile that pod with the dynamic fields.
+ return PodGenerator.reconcile_pods(pod_with_executor_config, dynamic_pod)
+
+
+def merge_objects(base_obj, client_obj):
+ """
+ :param base_obj: has the base attributes which are overwritten if they exist
+ in the client_obj and remain if they do not exist in the client_obj
+ :param client_obj: the object that the client wants to create.
+ :return: the merged objects
+ """
+ if not base_obj:
+ return client_obj
+ if not client_obj:
+ return base_obj
+
+ client_obj_cp = copy.deepcopy(client_obj)
+
+ for base_key in base_obj.to_dict().keys():
+ base_val = getattr(base_obj, base_key, None)
+ if not getattr(client_obj, base_key, None) and base_val:
+ setattr(client_obj_cp, base_key, base_val)
+ return client_obj_cp
+
+
+def extend_object_field(base_obj, client_obj, field_name):
+ """
+ :param base_obj: an object which has a property `field_name` that is a list
+ :param client_obj: an object which has a property `field_name` that is a list.
+ A copy of this object is returned with `field_name` modified
+ :param field_name: the name of the list field
+ :type field_name: str
+ :return: a copy of client_obj whose `field_name` list is base_obj's list followed by client_obj's list
+ """
+ client_obj_cp = copy.deepcopy(client_obj)
+ base_obj_field = getattr(base_obj, field_name, None)
+ client_obj_field = getattr(client_obj, field_name, None)
+
+ if (not isinstance(base_obj_field, list) and base_obj_field is not None) or \
+ (not isinstance(client_obj_field, list) and client_obj_field is not None):
+ raise ValueError("The chosen field must be a list.")
+
+ if not base_obj_field:
+ return client_obj_cp
+ if not client_obj_field:
+ setattr(client_obj_cp, field_name, base_obj_field)
+ return client_obj_cp
+
+ appended_fields = base_obj_field + client_obj_field
+ setattr(client_obj_cp, field_name, appended_fields)
+ return client_obj_cp
diff --git a/airflow/kubernetes/worker_configuration.py b/airflow/kubernetes/worker_configuration.py
index bed1ac2..3464e81 100644
--- a/airflow/kubernetes/worker_configuration.py
+++ b/airflow/kubernetes/worker_configuration.py
@@ -25,7 +25,6 @@ from airflow.kubernetes.k8s_model import append_to_pod
from airflow.kubernetes.pod_generator import PodGenerator
from airflow.kubernetes.secret import Secret
from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.version import version as airflow_version
class WorkerConfiguration(LoggingMixin):
@@ -418,23 +417,12 @@ class WorkerConfiguration(LoggingMixin):
return self.kube_config.git_dags_folder_mount_point
- def make_pod(self, namespace, worker_uuid, pod_id, dag_id, task_id, execution_date,
- try_number, airflow_command):
+ def as_pod(self):
+ """Creates POD."""
pod_generator = PodGenerator(
- namespace=namespace,
- name=pod_id,
image=self.kube_config.kube_image,
image_pull_policy=self.kube_config.kube_image_pull_policy,
- labels={
- 'airflow-worker': worker_uuid,
- 'dag_id': dag_id,
- 'task_id': task_id,
- 'execution_date': execution_date,
- 'try_number': str(try_number),
- 'airflow_version': airflow_version.replace('+', '-'),
- 'kubernetes_executor': 'True',
- },
- cmds=airflow_command,
+ image_pull_secrets=self.kube_config.image_pull_secrets,
volumes=self._get_volumes(),
volume_mounts=self._get_volume_mounts(),
init_containers=self._get_init_containers(),
diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py
index 993c47a..2b3ed17 100644
--- a/tests/executors/test_kubernetes_executor.py
+++ b/tests/executors/test_kubernetes_executor.py
@@ -26,12 +26,12 @@ from urllib3 import HTTPResponse
from airflow.utils import timezone
from tests.compat import mock
-
+from tests.test_utils.config import conf_vars
try:
from kubernetes.client.rest import ApiException
from airflow import configuration # noqa: F401
from airflow.configuration import conf # noqa: F401
- from airflow.executors.kubernetes_executor import AirflowKubernetesScheduler
+ from airflow.executors.kubernetes_executor import AirflowKubernetesScheduler, KubeConfig
from airflow.executors.kubernetes_executor import KubernetesExecutor
from airflow.kubernetes import pod_generator
from airflow.kubernetes.pod_generator import PodGenerator
@@ -124,6 +124,56 @@ class TestAirflowKubernetesScheduler(unittest.TestCase):
self.assertEqual(datetime_obj, new_datetime_obj)
+class TestKubeConfig(unittest.TestCase):
+ def setUp(self):
+ if AirflowKubernetesScheduler is None:
+ self.skipTest("kubernetes python package is not installed")
+
+ @conf_vars({
+ ('kubernetes', 'git_ssh_known_hosts_configmap_name'): 'airflow-configmap',
+ ('kubernetes', 'git_ssh_key_secret_name'): 'airflow-secrets',
+ ('kubernetes_annotations', "iam.com/role"): "role-arn",
+ ('kubernetes_annotations', "other/annotation"): "value"
+ })
+ def test_kube_config_worker_annotations_properly_parsed(self):
+ annotations = KubeConfig().kube_annotations
+ self.assertEqual({'iam.com/role': 'role-arn', 'other/annotation': 'value'}, annotations)
+
+ @conf_vars({
+ ('kubernetes', 'git_ssh_known_hosts_configmap_name'): 'airflow-configmap',
+ ('kubernetes', 'git_ssh_key_secret_name'): 'airflow-secrets'
+ })
+ def test_kube_config_no_worker_annotations(self):
+ annotations = KubeConfig().kube_annotations
+ self.assertIsNone(annotations)
+
+ @conf_vars({
+ ('kubernetes', 'git_repo'): 'foo',
+ ('kubernetes', 'git_branch'): 'foo',
+ ('kubernetes', 'git_dags_folder_mount_point'): 'foo',
+ ('kubernetes', 'git_sync_run_as_user'): '0',
+ })
+ def test_kube_config_git_sync_run_as_user_root(self):
+ self.assertEqual(KubeConfig().git_sync_run_as_user, 0)
+
+ @conf_vars({
+ ('kubernetes', 'git_repo'): 'foo',
+ ('kubernetes', 'git_branch'): 'foo',
+ ('kubernetes', 'git_dags_folder_mount_point'): 'foo',
+ })
+ def test_kube_config_git_sync_run_as_user_not_present(self):
+ self.assertEqual(KubeConfig().git_sync_run_as_user, 65533)
+
+ @conf_vars({
+ ('kubernetes', 'git_repo'): 'foo',
+ ('kubernetes', 'git_branch'): 'foo',
+ ('kubernetes', 'git_dags_folder_mount_point'): 'foo',
+ ('kubernetes', 'git_sync_run_as_user'): '',
+ })
+ def test_kube_config_git_sync_run_as_user_empty_string(self):
+ self.assertEqual(KubeConfig().git_sync_run_as_user, '')
+
+
class TestKubernetesExecutor(unittest.TestCase):
"""
Tests if an ApiException from the Kube Client will cause the task to
@@ -136,44 +186,45 @@ class TestKubernetesExecutor(unittest.TestCase):
@mock.patch('airflow.executors.kubernetes_executor.get_kube_client')
def test_run_next_exception(self, mock_get_kube_client, mock_kubernetes_job_watcher):
# When a quota is exceeded this is the ApiException we get
- r = HTTPResponse(
+ response = HTTPResponse(
body='{"kind": "Status", "apiVersion": "v1", "metadata": {}, "status": "Failure", '
'"message": "pods \\"podname\\" is forbidden: exceeded quota: compute-resources, '
'requested: limits.memory=4Gi, used: limits.memory=6508Mi, limited: limits.memory=10Gi", '
'"reason": "Forbidden", "details": {"name": "podname", "kind": "pods"}, "code": 403}')
- r.status = 403
- r.reason = "Forbidden"
+ response.status = 403
+ response.reason = "Forbidden"
# A mock kube_client that throws errors when making a pod
mock_kube_client = mock.patch('kubernetes.client.CoreV1Api', autospec=True)
mock_kube_client.create_namespaced_pod = mock.MagicMock(
- side_effect=ApiException(http_resp=r))
+ side_effect=ApiException(http_resp=response))
mock_get_kube_client.return_value = mock_kube_client
mock_api_client = mock.MagicMock()
mock_api_client.sanitize_for_serialization.return_value = {}
mock_kube_client.api_client = mock_api_client
- kubernetesExecutor = KubernetesExecutor()
- kubernetesExecutor.start()
+ kubernetes_executor = KubernetesExecutor()
+ kubernetes_executor.start()
# Execute a task while the Api Throws errors
try_number = 1
- kubernetesExecutor.execute_async(key=('dag', 'task', datetime.utcnow(), try_number),
- command=['airflow', 'run', 'true', 'some_parameter'],
- executor_config={})
- kubernetesExecutor.sync()
- kubernetesExecutor.sync()
+ kubernetes_executor.execute_async(key=('dag', 'task', datetime.utcnow(), try_number),
+ queue=None,
+ command=['airflow', 'run', 'command'],
+ executor_config={})
+ kubernetes_executor.sync()
+ kubernetes_executor.sync()
assert mock_kube_client.create_namespaced_pod.called
- self.assertFalse(kubernetesExecutor.task_queue.empty())
+ self.assertFalse(kubernetes_executor.task_queue.empty())
# Disable the ApiException
mock_kube_client.create_namespaced_pod.side_effect = None
# Execute the task without errors should empty the queue
- kubernetesExecutor.sync()
+ kubernetes_executor.sync()
assert mock_kube_client.create_namespaced_pod.called
- self.assertTrue(kubernetesExecutor.task_queue.empty())
+ self.assertTrue(kubernetes_executor.task_queue.empty())
@mock.patch('airflow.executors.kubernetes_executor.KubeConfig')
@mock.patch('airflow.executors.kubernetes_executor.KubernetesExecutor.sync')
@@ -187,22 +238,19 @@ class TestKubernetesExecutor(unittest.TestCase):
mock.call('executor.running_tasks', mock.ANY)]
mock_stats_gauge.assert_has_calls(calls)
- @mock.patch('airflow.executors.kubernetes_executor.KubeConfig')
@mock.patch('airflow.executors.kubernetes_executor.KubernetesJobWatcher')
@mock.patch('airflow.executors.kubernetes_executor.get_kube_client')
- def test_change_state_running(self, mock_get_kube_client, mock_kubernetes_job_watcher, mock_kube_config):
+ def test_change_state_running(self, mock_get_kube_client, mock_kubernetes_job_watcher):
executor = KubernetesExecutor()
executor.start()
key = ('dag_id', 'task_id', 'ex_time', 'try_number1')
executor._change_state(key, State.RUNNING, 'pod_id', 'default')
self.assertTrue(executor.event_buffer[key] == State.RUNNING)
- @mock.patch('airflow.executors.kubernetes_executor.KubeConfig')
@mock.patch('airflow.executors.kubernetes_executor.KubernetesJobWatcher')
@mock.patch('airflow.executors.kubernetes_executor.get_kube_client')
@mock.patch('airflow.executors.kubernetes_executor.AirflowKubernetesScheduler.delete_pod')
- def test_change_state_success(self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher,
- mock_kube_config):
+ def test_change_state_success(self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher):
executor = KubernetesExecutor()
executor.start()
test_time = timezone.utcnow()
@@ -211,12 +259,10 @@ class TestKubernetesExecutor(unittest.TestCase):
self.assertTrue(executor.event_buffer[key] == State.SUCCESS)
mock_delete_pod.assert_called_once_with('pod_id', 'default')
- @mock.patch('airflow.executors.kubernetes_executor.KubeConfig')
@mock.patch('airflow.executors.kubernetes_executor.KubernetesJobWatcher')
@mock.patch('airflow.executors.kubernetes_executor.get_kube_client')
@mock.patch('airflow.executors.kubernetes_executor.AirflowKubernetesScheduler.delete_pod')
- def test_change_state_failed(self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher,
- mock_kube_config):
+ def test_change_state_failed(self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher):
executor = KubernetesExecutor()
executor.kube_config.delete_worker_pods = False
executor.kube_config.delete_worker_pods_on_failure = False
@@ -227,12 +273,11 @@ class TestKubernetesExecutor(unittest.TestCase):
self.assertTrue(executor.event_buffer[key] == State.FAILED)
mock_delete_pod.assert_not_called()
- @mock.patch('airflow.executors.kubernetes_executor.KubeConfig')
@mock.patch('airflow.executors.kubernetes_executor.KubernetesJobWatcher')
@mock.patch('airflow.executors.kubernetes_executor.get_kube_client')
@mock.patch('airflow.executors.kubernetes_executor.AirflowKubernetesScheduler.delete_pod')
def test_change_state_skip_pod_deletion(self, mock_delete_pod, mock_get_kube_client,
- mock_kubernetes_job_watcher, mock_kube_config):
+ mock_kubernetes_job_watcher):
test_time = timezone.utcnow()
executor = KubernetesExecutor()
executor.kube_config.delete_worker_pods = False
@@ -243,12 +288,11 @@ class TestKubernetesExecutor(unittest.TestCase):
self.assertTrue(executor.event_buffer[key] == State.SUCCESS)
mock_delete_pod.assert_not_called()
- @mock.patch('airflow.executors.kubernetes_executor.KubeConfig')
@mock.patch('airflow.executors.kubernetes_executor.KubernetesJobWatcher')
@mock.patch('airflow.executors.kubernetes_executor.get_kube_client')
@mock.patch('airflow.executors.kubernetes_executor.AirflowKubernetesScheduler.delete_pod')
def test_change_state_failed_pod_deletion(self, mock_delete_pod, mock_get_kube_client,
- mock_kubernetes_job_watcher, mock_kube_config):
+ mock_kubernetes_job_watcher):
executor = KubernetesExecutor()
executor.kube_config.delete_worker_pods_on_failure = True
diff --git a/tests/kubernetes/test_pod_generator.py b/tests/kubernetes/test_pod_generator.py
index 30839e7..a9a3aa5 100644
--- a/tests/kubernetes/test_pod_generator.py
+++ b/tests/kubernetes/test_pod_generator.py
@@ -20,15 +20,17 @@ from tests.compat import mock
import uuid
import kubernetes.client.models as k8s
from kubernetes.client import ApiClient
-from airflow.kubernetes.secret import Secret
-from airflow.kubernetes.pod_generator import PodGenerator, PodDefaults
-from airflow.kubernetes.pod import Resources
+
from airflow.kubernetes.k8s_model import append_to_pod
+from airflow.kubernetes.pod import Resources
+from airflow.kubernetes.pod_generator import PodDefaults, PodGenerator, extend_object_field, merge_objects
+from airflow.kubernetes.secret import Secret
class TestPodGenerator(unittest.TestCase):
def setUp(self):
+ self.static_uuid = uuid.UUID('cf4a56d2-8101-4217-b027-2af6216feb48')
self.envs = {
'ENVIRONMENT': 'prod',
'LOG_LEVEL': 'warning'
@@ -41,9 +43,23 @@ class TestPodGenerator(unittest.TestCase):
# This should produce a single secret mounted in env
Secret('env', 'TARGET', 'secret_b', 'source_b'),
]
+ self.labels = {
+ 'airflow-worker': 'uuid',
+ 'dag_id': 'dag_id',
+ 'execution_date': 'date',
+ 'task_id': 'task_id',
+ 'try_number': '3',
+ 'airflow_version': mock.ANY,
+ 'kubernetes_executor': 'True'
+ }
+ self.metadata = {
+ 'labels': self.labels,
+ 'name': 'pod_id-' + self.static_uuid.hex,
+ 'namespace': 'namespace'
+ }
+
self.resources = Resources('1Gi', 1, '2Gi', 2, 1)
self.k8s_client = ApiClient()
- self.static_uuid = uuid.UUID('cf4a56d2-8101-4217-b027-2af6216feb48')
self.expected = {
'apiVersion': 'v1',
'kind': 'Pod',
@@ -171,9 +187,9 @@ class TestPodGenerator(unittest.TestCase):
fs_group=2000,
),
ports=[k8s.V1ContainerPort(name='foo', container_port=1234)],
- configmaps=['configmap_a', 'configmap_b']
+ configmaps=['configmap_a', 'configmap_b'],
+ extract_xcom=True
)
- pod_generator.extract_xcom = True
result = pod_generator.gen_pod()
result = append_to_pod(result, self.secrets)
result = self.resources.attach_to_pod(result)
@@ -253,79 +269,452 @@ class TestPodGenerator(unittest.TestCase):
}
}, result)
- def test_reconcile_pods(self):
- with mock.patch('uuid.uuid4') as mock_uuid:
- mock_uuid.return_value = self.static_uuid
- base_pod = PodGenerator(
- image='image1',
- name='name1',
- envs={'key1': 'val1'},
- cmds=['/bin/command1.sh', 'arg1'],
- ports=k8s.V1ContainerPort(name='port', container_port=2118),
- volumes=[{
- 'hostPath': {'path': '/tmp/'},
- 'name': 'example-kubernetes-test-volume1'
- }],
- volume_mounts=[{
- 'mountPath': '/foo/',
- 'name': 'example-kubernetes-test-volume1'
- }],
- ).gen_pod()
-
- mutator_pod = PodGenerator(
- envs={'key2': 'val2'},
- image='',
- name='name2',
- cmds=['/bin/command2.sh', 'arg2'],
- volumes=[{
- 'hostPath': {'path': '/tmp/'},
- 'name': 'example-kubernetes-test-volume2'
- }],
- volume_mounts=[{
- 'mountPath': '/foo/',
- 'name': 'example-kubernetes-test-volume2'
- }]
- ).gen_pod()
-
- result = PodGenerator.reconcile_pods(base_pod, mutator_pod)
- result = self.k8s_client.sanitize_for_serialization(result)
- self.assertEqual(result, {
- 'apiVersion': 'v1',
- 'kind': 'Pod',
- 'metadata': {'name': 'name2-' + self.static_uuid.hex},
- 'spec': {
- 'containers': [{
- 'args': [],
- 'command': ['/bin/command1.sh', 'arg1'],
- 'env': [
- {'name': 'key1', 'value': 'val1'},
- {'name': 'key2', 'value': 'val2'}
- ],
- 'envFrom': [],
- 'image': 'image1',
- 'imagePullPolicy': 'IfNotPresent',
- 'name': 'base',
- 'ports': {
- 'containerPort': 2118,
- 'name': 'port',
- },
- 'volumeMounts': [{
- 'mountPath': '/foo/',
- 'name': 'example-kubernetes-test-volume1'
- }, {
- 'mountPath': '/foo/',
- 'name': 'example-kubernetes-test-volume2'
- }]
+ @mock.patch('uuid.uuid4')
+ def test_reconcile_pods_empty_mutator_pod(self, mock_uuid):
+ mock_uuid.return_value = self.static_uuid
+ base_pod = PodGenerator(
+ image='image1',
+ name='name1',
+ envs={'key1': 'val1'},
+ cmds=['/bin/command1.sh', 'arg1'],
+ ports=[k8s.V1ContainerPort(name='port', container_port=2118)],
+ volumes=[{
+ 'hostPath': {'path': '/tmp/'},
+ 'name': 'example-kubernetes-test-volume1'
+ }],
+ volume_mounts=[{
+ 'mountPath': '/foo/',
+ 'name': 'example-kubernetes-test-volume1'
+ }],
+ ).gen_pod()
+
+ mutator_pod = None
+ name = 'name1-' + self.static_uuid.hex
+
+ base_pod.metadata.name = name
+
+ result = PodGenerator.reconcile_pods(base_pod, mutator_pod)
+ self.assertEqual(base_pod, result)
+
+ mutator_pod = k8s.V1Pod()
+ result = PodGenerator.reconcile_pods(base_pod, mutator_pod)
+ self.assertEqual(base_pod, result)
+
+ @mock.patch('uuid.uuid4')
+ def test_reconcile_pods(self, mock_uuid):
+ mock_uuid.return_value = self.static_uuid
+ base_pod = PodGenerator(
+ image='image1',
+ name='name1',
+ envs={'key1': 'val1'},
+ cmds=['/bin/command1.sh', 'arg1'],
+ ports=[k8s.V1ContainerPort(name='port', container_port=2118)],
+ volumes=[{
+ 'hostPath': {'path': '/tmp/'},
+ 'name': 'example-kubernetes-test-volume1'
+ }],
+ volume_mounts=[{
+ 'mountPath': '/foo/',
+ 'name': 'example-kubernetes-test-volume1'
+ }],
+ ).gen_pod()
+
+ mutator_pod = PodGenerator(
+ envs={'key2': 'val2'},
+ image='',
+ name='name2',
+ cmds=['/bin/command2.sh', 'arg2'],
+ volumes=[{
+ 'hostPath': {'path': '/tmp/'},
+ 'name': 'example-kubernetes-test-volume2'
+ }],
+ volume_mounts=[{
+ 'mountPath': '/foo/',
+ 'name': 'example-kubernetes-test-volume2'
+ }]
+ ).gen_pod()
+
+ result = PodGenerator.reconcile_pods(base_pod, mutator_pod)
+ result = self.k8s_client.sanitize_for_serialization(result)
+ self.assertEqual(result, {
+ 'apiVersion': 'v1',
+ 'kind': 'Pod',
+ 'metadata': {'name': 'name2-' + self.static_uuid.hex},
+ 'spec': {
+ 'containers': [{
+ 'args': [],
+ 'command': ['/bin/command2.sh', 'arg2'],
+ 'env': [
+ {'name': 'key1', 'value': 'val1'},
+ {'name': 'key2', 'value': 'val2'}
+ ],
+ 'envFrom': [],
+ 'image': 'image1',
+ 'imagePullPolicy': 'IfNotPresent',
+ 'name': 'base',
+ 'ports': [{
+ 'containerPort': 2118,
+ 'name': 'port',
}],
- 'hostNetwork': False,
- 'imagePullSecrets': [],
- 'restartPolicy': 'Never',
- 'volumes': [{
- 'hostPath': {'path': '/tmp/'},
+ 'volumeMounts': [{
+ 'mountPath': '/foo/',
'name': 'example-kubernetes-test-volume1'
}, {
- 'hostPath': {'path': '/tmp/'},
+ 'mountPath': '/foo/',
'name': 'example-kubernetes-test-volume2'
}]
+ }],
+ 'hostNetwork': False,
+ 'imagePullSecrets': [],
+ 'restartPolicy': 'Never',
+ 'volumes': [{
+ 'hostPath': {'path': '/tmp/'},
+ 'name': 'example-kubernetes-test-volume1'
+ }, {
+ 'hostPath': {'path': '/tmp/'},
+ 'name': 'example-kubernetes-test-volume2'
+ }]
+ }
+ })
+
+ @mock.patch('uuid.uuid4')
+ def test_construct_pod_empty_worker_config(self, mock_uuid):
+ mock_uuid.return_value = self.static_uuid
+ executor_config = k8s.V1Pod(
+ spec=k8s.V1PodSpec(
+ containers=[
+ k8s.V1Container(
+ name='',
+ resources=k8s.V1ResourceRequirements(
+ limits={
+ 'cpu': '1m',
+ 'memory': '1G'
+ }
+ )
+ )
+ ]
+ )
+ )
+ worker_config = k8s.V1Pod()
+
+ result = PodGenerator.construct_pod(
+ 'dag_id',
+ 'task_id',
+ 'pod_id',
+ 3,
+ 'date',
+ ['command'],
+ executor_config,
+ worker_config,
+ 'namespace',
+ 'uuid',
+ )
+ sanitized_result = self.k8s_client.sanitize_for_serialization(result)
+
+ self.assertEqual({
+ 'apiVersion': 'v1',
+ 'kind': 'Pod',
+ 'metadata': self.metadata,
+ 'spec': {
+ 'containers': [{
+ 'args': [],
+ 'command': ['command'],
+ 'env': [],
+ 'envFrom': [],
+ 'imagePullPolicy': 'IfNotPresent',
+ 'name': 'base',
+ 'ports': [],
+ 'resources': {
+ 'limits': {
+ 'cpu': '1m',
+ 'memory': '1G'
+ }
+ },
+ 'volumeMounts': []
+ }],
+ 'hostNetwork': False,
+ 'imagePullSecrets': [],
+ 'restartPolicy': 'Never',
+ 'volumes': []
+ }
+ }, sanitized_result)
+
+    @mock.patch('uuid.uuid4')
+    def test_construct_pod_empty_executor_config(self, mock_uuid):
+        mock_uuid.return_value = self.static_uuid  # pins the pod name suffix checked in self.metadata
+        worker_config = k8s.V1Pod(
+            spec=k8s.V1PodSpec(
+                containers=[
+                    k8s.V1Container(
+                        name='',
+                        resources=k8s.V1ResourceRequirements(
+                            limits={
+                                'cpu': '1m',
+                                'memory': '1G'
+                            }
+                        )
+                    )
+                ]
+            )
+        )
+        executor_config = None  # no per-task override: the resources below must come from worker_config
+
+        result = PodGenerator.construct_pod(
+            'dag_id',
+            'task_id',
+            'pod_id',
+            3,
+            'date',
+            ['command'],
+            executor_config,
+            worker_config,
+            'namespace',
+            'uuid',
+        )
+        sanitized_result = self.k8s_client.sanitize_for_serialization(result)
+
+        self.assertEqual({
+            'apiVersion': 'v1',
+            'kind': 'Pod',
+            'metadata': self.metadata,
+            'spec': {
+                'containers': [{
+                    'args': [],
+                    'command': ['command'],
+                    'env': [],
+                    'envFrom': [],
+                    'imagePullPolicy': 'IfNotPresent',
+                    'name': 'base',
+                    'ports': [],
+                    'resources': {
+                        'limits': {
+                            'cpu': '1m',
+                            'memory': '1G'
+                        }
+                    },
+                    'volumeMounts': []
+                }],
+                'hostNetwork': False,
+                'imagePullSecrets': [],
+                'restartPolicy': 'Never',
+                'volumes': []
+            }
+        }, sanitized_result)
+
+ @mock.patch('uuid.uuid4')
+ def test_construct_pod(self, mock_uuid):
+ mock_uuid.return_value = self.static_uuid
+ worker_config = k8s.V1Pod(
+ metadata=k8s.V1ObjectMeta(
+ name='gets-overridden-by-dynamic-args',
+ annotations={
+ 'should': 'stay'
}
- })
+ ),
+ spec=k8s.V1PodSpec(
+ containers=[
+ k8s.V1Container(
+ name='doesnt-override',
+ resources=k8s.V1ResourceRequirements(
+ limits={
+ 'cpu': '1m',
+ 'memory': '1G'
+ }
+ ),
+ security_context=k8s.V1SecurityContext(
+ run_as_user=1
+ )
+ )
+ ]
+ )
+ )
+ executor_config = k8s.V1Pod(
+ spec=k8s.V1PodSpec(
+ containers=[
+ k8s.V1Container(
+ name='doesnt-override-either',
+ resources=k8s.V1ResourceRequirements(
+ limits={
+ 'cpu': '2m',
+ 'memory': '2G'
+ }
+ )
+ )
+ ]
+ )
+ )
+
+ result = PodGenerator.construct_pod(
+ 'dag_id',
+ 'task_id',
+ 'pod_id',
+ 3,
+ 'date',
+ ['command'],
+ executor_config,
+ worker_config,
+ 'namespace',
+ 'uuid',
+ )
+ sanitized_result = self.k8s_client.sanitize_for_serialization(result)
+
+ self.metadata.update({'annotations': {'should': 'stay'}})
+
+ self.assertEqual({
+ 'apiVersion': 'v1',
+ 'kind': 'Pod',
+ 'metadata': self.metadata,
+ 'spec': {
+ 'containers': [{
+ 'args': [],
+ 'command': ['command'],
+ 'env': [],
+ 'envFrom': [],
+ 'imagePullPolicy': 'IfNotPresent',
+ 'name': 'base',
+ 'ports': [],
+ 'resources': {
+ 'limits': {
+ 'cpu': '2m',
+ 'memory': '2G'
+ }
+ },
+ 'volumeMounts': [],
+ 'securityContext': {'runAsUser': 1}
+ }],
+ 'hostNetwork': False,
+ 'imagePullSecrets': [],
+ 'restartPolicy': 'Never',
+ 'volumes': []
+ }
+ }, sanitized_result)
+
+ def test_merge_objects_empty(self):
+ annotations = {'foo1': 'bar1'}
+ base_obj = k8s.V1ObjectMeta(annotations=annotations)
+ client_obj = None
+ res = merge_objects(base_obj, client_obj)
+ self.assertEqual(base_obj, res)
+
+ client_obj = k8s.V1ObjectMeta()
+ res = merge_objects(base_obj, client_obj)
+ self.assertEqual(base_obj, res)
+
+ client_obj = k8s.V1ObjectMeta(annotations=annotations)
+ base_obj = None
+ res = merge_objects(base_obj, client_obj)
+ self.assertEqual(client_obj, res)
+
+ base_obj = k8s.V1ObjectMeta()
+ res = merge_objects(base_obj, client_obj)
+ self.assertEqual(client_obj, res)
+
+ def test_merge_objects(self):
+ base_annotations = {'foo1': 'bar1'}
+ base_labels = {'foo1': 'bar1'}
+ client_annotations = {'foo2': 'bar2'}
+ base_obj = k8s.V1ObjectMeta(
+ annotations=base_annotations,
+ labels=base_labels
+ )
+ client_obj = k8s.V1ObjectMeta(annotations=client_annotations)
+ res = merge_objects(base_obj, client_obj)
+ client_obj.labels = base_labels
+ self.assertEqual(client_obj, res)
+
+    def test_extend_object_field_empty(self):
+        ports = [k8s.V1ContainerPort(container_port=1, name='port')]
+        base_obj = k8s.V1Container(name='base_container', ports=ports)
+        client_obj = k8s.V1Container(name='client_container')
+        res = extend_object_field(base_obj, client_obj, 'ports')  # client has no ports: base's are used
+        client_obj.ports = ports
+        self.assertEqual(client_obj, res)
+
+        base_obj = k8s.V1Container(name='base_container')
+        client_obj = k8s.V1Container(name='client_container', ports=ports)
+        res = extend_object_field(base_obj, client_obj, 'ports')  # base has no ports: client's are kept
+        self.assertEqual(client_obj, res)
+
+ def test_extend_object_field_not_list(self):
+ base_obj = k8s.V1Container(name='base_container', image='image')
+ client_obj = k8s.V1Container(name='client_container')
+ with self.assertRaises(ValueError):
+ extend_object_field(base_obj, client_obj, 'image')
+ base_obj = k8s.V1Container(name='base_container')
+ client_obj = k8s.V1Container(name='client_container', image='image')
+ with self.assertRaises(ValueError):
+ extend_object_field(base_obj, client_obj, 'image')
+
+ def test_extend_object_field(self):
+ base_ports = [k8s.V1ContainerPort(container_port=1, name='base_port')]
+ base_obj = k8s.V1Container(name='base_container', ports=base_ports)
+ client_ports = [k8s.V1ContainerPort(container_port=1, name='client_port')]
+ client_obj = k8s.V1Container(name='client_container', ports=client_ports)
+ res = extend_object_field(base_obj, client_obj, 'ports')
+ client_obj.ports = base_ports + client_ports
+ self.assertEqual(client_obj, res)
+
+ def test_reconcile_containers_empty(self):
+ base_objs = [k8s.V1Container(name='base_container')]
+ client_objs = []
+ res = PodGenerator.reconcile_containers(base_objs, client_objs)
+ self.assertEqual(base_objs, res)
+
+ client_objs = [k8s.V1Container(name='client_container')]
+ base_objs = []
+ res = PodGenerator.reconcile_containers(base_objs, client_objs)
+ self.assertEqual(client_objs, res)
+
+ res = PodGenerator.reconcile_containers([], [])
+ self.assertEqual(res, [])
+
+ def test_reconcile_containers(self):
+ base_ports = [k8s.V1ContainerPort(container_port=1, name='base_port')]
+ base_objs = [
+ k8s.V1Container(name='base_container1', ports=base_ports),
+ k8s.V1Container(name='base_container2', image='base_image'),
+ ]
+ client_ports = [k8s.V1ContainerPort(container_port=2, name='client_port')]
+ client_objs = [
+ k8s.V1Container(name='client_container1', ports=client_ports),
+ k8s.V1Container(name='client_container2', image='client_image'),
+ ]
+ res = PodGenerator.reconcile_containers(base_objs, client_objs)
+ client_objs[0].ports = base_ports + client_ports
+ self.assertEqual(client_objs, res)
+
+ base_ports = [k8s.V1ContainerPort(container_port=1, name='base_port')]
+ base_objs = [
+ k8s.V1Container(name='base_container1', ports=base_ports),
+ k8s.V1Container(name='base_container2', image='base_image'),
+ ]
+ client_ports = [k8s.V1ContainerPort(container_port=2, name='client_port')]
+ client_objs = [
+ k8s.V1Container(name='client_container1', ports=client_ports),
+ k8s.V1Container(name='client_container2', stdin=True),
+ ]
+ res = PodGenerator.reconcile_containers(base_objs, client_objs)
+ client_objs[0].ports = base_ports + client_ports
+ client_objs[1].image = 'base_image'
+ self.assertEqual(client_objs, res)
+
+ def test_reconcile_specs_empty(self):
+ base_spec = k8s.V1PodSpec(containers=[])
+ client_spec = None
+ res = PodGenerator.reconcile_specs(base_spec, client_spec)
+ self.assertEqual(base_spec, res)
+
+ base_spec = None
+ client_spec = k8s.V1PodSpec(containers=[])
+ res = PodGenerator.reconcile_specs(base_spec, client_spec)
+ self.assertEqual(client_spec, res)
+
+ def test_reconcile_specs(self):
+ base_objs = [k8s.V1Container(name='base_container1', image='base_image')]
+ client_objs = [k8s.V1Container(name='client_container1')]
+ base_spec = k8s.V1PodSpec(priority=1, active_deadline_seconds=100, containers=base_objs)
+ client_spec = k8s.V1PodSpec(priority=2, hostname='local', containers=client_objs)
+ res = PodGenerator.reconcile_specs(base_spec, client_spec)
+ client_spec.containers = [k8s.V1Container(name='client_container1', image='base_image')]
+ client_spec.active_deadline_seconds = 100
+ self.assertEqual(client_spec, res)
diff --git a/tests/kubernetes/test_worker_configuration.py b/tests/kubernetes/test_worker_configuration.py
index 8378f9f..74009a1 100644
--- a/tests/kubernetes/test_worker_configuration.py
+++ b/tests/kubernetes/test_worker_configuration.py
@@ -17,13 +17,12 @@
#
import unittest
-import uuid
-from datetime import datetime
import six
from tests.compat import mock
from tests.test_utils.config import conf_vars
+
try:
from airflow.executors.kubernetes_executor import AirflowKubernetesScheduler
from airflow.executors.kubernetes_executor import KubeConfig
@@ -31,6 +30,7 @@ try:
from airflow.kubernetes.pod_generator import PodGenerator
from airflow.exceptions import AirflowConfigException
from airflow.kubernetes.secret import Secret
+ from airflow.version import version as airflow_version
import kubernetes.client.models as k8s
from kubernetes.client.api_client import ApiClient
except ImportError:
@@ -74,6 +74,11 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
}
]
+ worker_annotations_config = {
+ 'iam.amazonaws.com/role': 'role-arn',
+ 'other/annotation': 'value'
+ }
+
def setUp(self):
if AirflowKubernetesScheduler is None:
self.skipTest("kubernetes python package is not installed")
@@ -312,11 +317,39 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
self.kube_config.git_subpath = 'path'
worker_config = WorkerConfiguration(self.kube_config)
- pod = worker_config.make_pod("default", str(uuid.uuid4()), "test_pod_id", "test_dag_id",
- "test_task_id", str(datetime.utcnow()), 1, "bash -c 'ls /'")
+ pod = worker_config.as_pod()
self.assertEqual(0, pod.spec.security_context.run_as_user)
+ def test_make_pod_assert_labels(self):
+ # Tests the pod created has all the expected labels set
+ self.kube_config.dags_folder = 'dags'
+
+ worker_config = WorkerConfiguration(self.kube_config)
+ pod = PodGenerator.construct_pod(
+ "test_dag_id",
+ "test_task_id",
+ "test_pod_id",
+ 1,
+ "2019-11-21 11:08:22.920875",
+ ["bash -c 'ls /'"],
+ None,
+ worker_config.as_pod(),
+ "default",
+ "sample-uuid",
+
+ )
+ expected_labels = {
+ 'airflow-worker': 'sample-uuid',
+ 'airflow_version': airflow_version.replace('+', '-'),
+ 'dag_id': 'test_dag_id',
+ 'execution_date': '2019-11-21 11:08:22.920875',
+ 'kubernetes_executor': 'True',
+ 'task_id': 'test_task_id',
+ 'try_number': '1'
+ }
+ self.assertEqual(pod.metadata.labels, expected_labels)
+
def test_make_pod_git_sync_ssh_without_known_hosts(self):
# Tests the pod created with git-sync SSH authentication option is correct without known hosts
self.kube_config.airflow_configmap = 'airflow-configmap'
@@ -331,8 +364,7 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
worker_config = WorkerConfiguration(self.kube_config)
- pod = worker_config.make_pod("default", str(uuid.uuid4()), "test_pod_id", "test_dag_id",
- "test_task_id", str(datetime.utcnow()), 1, "bash -c 'ls /'")
+ pod = worker_config.as_pod()
init_containers = worker_config._get_init_containers()
git_ssh_key_file = next((x.value for x in init_containers[0].env
@@ -361,8 +393,7 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
worker_config = WorkerConfiguration(self.kube_config)
- pod = worker_config.make_pod("default", str(uuid.uuid4()), "test_pod_id", "test_dag_id",
- "test_task_id", str(datetime.utcnow()), 1, "bash -c 'ls /'")
+ pod = worker_config.as_pod()
username_env = k8s.V1EnvVar(
name='GIT_SYNC_USERNAME',
@@ -387,6 +418,29 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
self.assertIn(password_env, pod.spec.init_containers[0].env,
'The password env for git credentials did not get into the init container')
+    def test_make_pod_git_sync_rev(self):
+        # Tests that kube_config.git_sync_rev reaches the first init container as the GIT_SYNC_REV env var
+        self.kube_config.git_sync_rev = 'sampletag'
+        self.kube_config.dags_volume_claim = None
+        self.kube_config.dags_volume_host = None
+        self.kube_config.dags_in_image = None
+        self.kube_config.worker_fs_group = None
+        self.kube_config.git_dags_folder_mount_point = 'dags'
+        self.kube_config.git_sync_dest = 'repo'
+        self.kube_config.git_subpath = 'path'
+
+        worker_config = WorkerConfiguration(self.kube_config)
+
+        pod = worker_config.as_pod()
+
+        rev_env = k8s.V1EnvVar(
+            name='GIT_SYNC_REV',
+            value=self.kube_config.git_sync_rev,
+        )
+
+        self.assertIn(rev_env, pod.spec.init_containers[0].env,
+                      'The git_sync_rev env did not get into the init container')
+
def test_make_pod_git_sync_ssh_with_known_hosts(self):
# Tests the pod created with git-sync SSH authentication option is correct with known hosts
self.kube_config.airflow_configmap = 'airflow-configmap'
@@ -415,11 +469,10 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
def test_make_pod_with_empty_executor_config(self):
self.kube_config.kube_affinity = self.affinity_config
self.kube_config.kube_tolerations = self.tolerations_config
+ self.kube_config.kube_annotations = self.worker_annotations_config
self.kube_config.dags_folder = 'dags'
worker_config = WorkerConfiguration(self.kube_config)
-
- pod = worker_config.make_pod("default", str(uuid.uuid4()), "test_pod_id", "test_dag_id",
- "test_task_id", str(datetime.utcnow()), 1, "bash -c 'ls /'")
+ pod = worker_config.as_pod()
self.assertTrue(pod.spec.affinity['podAntiAffinity'] is not None)
self.assertEqual('app',
@@ -431,6 +484,8 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
self.assertEqual(2, len(pod.spec.tolerations))
self.assertEqual('prod', pod.spec.tolerations[1]['key'])
+ self.assertEqual('role-arn', pod.metadata.annotations['iam.amazonaws.com/role'])
+ self.assertEqual('value', pod.metadata.annotations['other/annotation'])
def test_make_pod_with_executor_config(self):
self.kube_config.dags_folder = 'dags'
@@ -441,8 +496,7 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
tolerations=self.tolerations_config,
).gen_pod()
- pod = worker_config.make_pod("default", str(uuid.uuid4()), "test_pod_id", "test_dag_id",
- "test_task_id", str(datetime.utcnow()), 1, "bash -c 'ls /'")
+ pod = worker_config.as_pod()
result = PodGenerator.reconcile_pods(pod, config_pod)
@@ -607,3 +661,18 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
'dag_id': 'override_dag_id',
'my_kube_executor_label': 'kubernetes'
}, labels)
+
+    def test_make_pod_with_image_pull_secrets(self):
+        # Tests that a comma-separated image_pull_secrets setting yields one pull secret per entry on the pod
+        self.kube_config.dags_volume_claim = None
+        self.kube_config.dags_volume_host = None
+        self.kube_config.dags_in_image = None
+        self.kube_config.git_dags_folder_mount_point = 'dags'
+        self.kube_config.git_sync_dest = 'repo'
+        self.kube_config.git_subpath = 'path'
+        self.kube_config.image_pull_secrets = 'image_pull_secret1,image_pull_secret2'
+
+        worker_config = WorkerConfiguration(self.kube_config)
+        pod = worker_config.as_pod()
+
+        self.assertEqual(2, len(pod.spec.image_pull_secrets))