You are viewing a plain-text version of this content. The canonical link for it was a hyperlink that did not survive conversion to plain text.
Posted to commits@airflow.apache.org by po...@apache.org on 2020/06/29 13:21:37 UTC
[airflow] 31/37: [AIRFLOW-5413] Allow K8S worker pod to be configured from JSON/YAML file (#6230)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit b222e1717b8dd466bfb6880e0e079bc19f60e383
Author: Daniel Imberman <da...@astronomer.io>
AuthorDate: Fri Jun 26 17:55:00 2020 -0700
[AIRFLOW-5413] Allow K8S worker pod to be configured from JSON/YAML file (#6230)
* [AIRFLOW-5413] enable pod config from file
* Update airflow/kubernetes/pod_generator.py
Co-Authored-By: Ash Berlin-Taylor <as...@firemirror.com>
* Update airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
Co-Authored-By: Ash Berlin-Taylor <as...@firemirror.com>
Co-authored-by: Ash Berlin-Taylor <as...@firemirror.com>
(cherry picked from commit 967930c0cb6e2293f2a49e5c9add5aa1917f3527)
---
airflow/config_templates/config.yml | 7 +
airflow/config_templates/default_airflow.cfg | 3 +
.../contrib/operators/kubernetes_pod_operator.py | 15 ++-
.../example_kubernetes_executor_config.py | 6 +-
airflow/executors/kubernetes_executor.py | 11 +-
airflow/kubernetes/pod_generator.py | 144 ++++++++++++++-------
airflow/kubernetes/worker_configuration.py | 20 ++-
kubernetes_tests/test_kubernetes_pod_operator.py | 50 ++++++-
tests/executors/test_kubernetes_executor.py | 2 +-
tests/kubernetes/models/test_pod.py | 2 -
tests/kubernetes/models/test_secret.py | 2 -
tests/kubernetes/pod.yaml | 33 +++++
tests/kubernetes/test_pod_generator.py | 80 +++++++++---
tests/kubernetes/test_worker_configuration.py | 14 +-
14 files changed, 304 insertions(+), 85 deletions(-)
diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index 9b63200..61491d8 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -1736,6 +1736,13 @@
type: string
example: ~
default: ""
+ - name: pod_template_file
+ description: |
+ Path to the YAML pod file. If set, all other kubernetes-related fields are ignored.
+ version_added: ~
+ type: string
+ example: ~
+ default: ""
- name: worker_container_tag
description: ~
version_added: ~
diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg
index ca9de12..2cc97e2 100644
--- a/airflow/config_templates/default_airflow.cfg
+++ b/airflow/config_templates/default_airflow.cfg
@@ -801,6 +801,9 @@ verify_certs = True
[kubernetes]
# The repository, tag and imagePullPolicy of the Kubernetes Image for the Worker to Run
worker_container_repository =
+
+# Path to the YAML pod file. If set, all other kubernetes-related fields are ignored.
+pod_template_file =
worker_container_tag =
worker_container_image_pull_policy = IfNotPresent
diff --git a/airflow/contrib/operators/kubernetes_pod_operator.py b/airflow/contrib/operators/kubernetes_pod_operator.py
index ce8f19c..41f0df3 100644
--- a/airflow/contrib/operators/kubernetes_pod_operator.py
+++ b/airflow/contrib/operators/kubernetes_pod_operator.py
@@ -130,8 +130,18 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance-
:type schedulername: str
:param full_pod_spec: The complete podSpec
:type full_pod_spec: kubernetes.client.models.V1Pod
+ :param init_containers: init container for the launched Pod
+ :type init_containers: list[kubernetes.client.models.V1Container]
+ :param log_events_on_failure: Log the pod's events if a failure occurs
+ :type log_events_on_failure: bool
+ :param do_xcom_push: If True, the content of the file
+ /airflow/xcom/return.json in the container will also be pushed to an
+ XCom when the container completes.
+ :type do_xcom_push: bool
+ :param pod_template_file: path to pod template file
+ :type pod_template_file: str
"""
- template_fields = ('cmds', 'arguments', 'env_vars', 'config_file')
+ template_fields = ('cmds', 'arguments', 'env_vars', 'config_file', 'pod_template_file')
@apply_defaults
def __init__(self, # pylint: disable=too-many-arguments,too-many-locals
@@ -215,8 +225,8 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance-
self.full_pod_spec = full_pod_spec
self.init_containers = init_containers or []
self.log_events_on_failure = log_events_on_failure
- self.priority_class_name = priority_class_name
self.pod_template_file = pod_template_file
+ self.priority_class_name = priority_class_name
self.name = self._set_name(name)
@staticmethod
@@ -348,6 +358,7 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance-
init_containers=self.init_containers,
restart_policy='Never',
schedulername=self.schedulername,
+ pod_template_file=self.pod_template_file,
priority_class_name=self.priority_class_name,
pod=self.full_pod_spec,
).gen_pod()
diff --git a/airflow/example_dags/example_kubernetes_executor_config.py b/airflow/example_dags/example_kubernetes_executor_config.py
index d740956..2e4ba00 100644
--- a/airflow/example_dags/example_kubernetes_executor_config.py
+++ b/airflow/example_dags/example_kubernetes_executor_config.py
@@ -83,14 +83,14 @@ with DAG(
}
)
- # Test that we can run tasks as a normal user
+ # Test that we can add labels to pods
third_task = PythonOperator(
task_id="non_root_task",
python_callable=print_stuff,
executor_config={
"KubernetesExecutor": {
- "securityContext": {
- "runAsUser": 1000
+ "labels": {
+ "release": "stable"
}
}
}
diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py
index 98e3154..e014aa3 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -20,7 +20,6 @@ import json
import multiprocessing
import time
from queue import Empty
-from uuid import uuid4
import kubernetes
from dateutil import parser
@@ -72,6 +71,9 @@ class KubeConfig:
)
self.kube_node_selectors = configuration_dict.get('kubernetes_node_selectors', {})
self.kube_annotations = configuration_dict.get('kubernetes_annotations', {}) or None
+ self.pod_template_file = conf.get(self.kubernetes_section, 'pod_template_file',
+ fallback=None)
+
self.kube_labels = configuration_dict.get('kubernetes_labels', {})
self.delete_worker_pods = conf.getboolean(
self.kubernetes_section, 'delete_worker_pods')
@@ -220,6 +222,8 @@ class KubeConfig:
return int(val)
def _validate(self):
+ if self.pod_template_file:
+ return
# TODO: use XOR for dags_volume_claim and git_dags_folder_mount_point
if not self.dags_volume_claim \
and not self.dags_volume_host \
@@ -498,10 +502,7 @@ class AirflowKubernetesScheduler(LoggingMixin):
dag_id)
safe_task_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
task_id)
- safe_uuid = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
- uuid4().hex)
- return AirflowKubernetesScheduler._make_safe_pod_id(safe_dag_id, safe_task_id,
- safe_uuid)
+ return safe_dag_id + safe_task_id
@staticmethod
def _label_safe_datestring_to_datetime(string):
diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py
index 711b1a9..e46407b 100644
--- a/airflow/kubernetes/pod_generator.py
+++ b/airflow/kubernetes/pod_generator.py
@@ -24,10 +24,20 @@ is supported and no serialization need be written.
import copy
import hashlib
import re
+try:
+ from inspect import signature
+except ImportError:
+ # Python 2.7
+ from funcsigs import signature # type: ignore
+import os
import uuid
+from functools import reduce
import kubernetes.client.models as k8s
+import yaml
+from kubernetes.client.api_client import ApiClient
+from airflow.exceptions import AirflowConfigException
from airflow.version import version as airflow_version
MAX_LABEL_LEN = 63
@@ -35,10 +45,14 @@ MAX_LABEL_LEN = 63
MAX_POD_ID_LEN = 253
-class PodDefaults:
+class PodDefaults(object):
"""
Static defaults for the PodGenerator
"""
+
+ def __init__(self):
+ pass
+
XCOM_MOUNT_PATH = '/airflow/xcom'
SIDECAR_CONTAINER_NAME = 'airflow-xcom-sidecar'
XCOM_CMD = 'trap "exit 0" INT; while true; do sleep 30; done;'
@@ -82,7 +96,7 @@ def make_safe_label_value(string):
return safe_label
-class PodGenerator:
+class PodGenerator(object):
"""
Contains Kubernetes Airflow Worker configuration logic
@@ -147,9 +161,11 @@ class PodGenerator:
:param dnspolicy: Specify a dnspolicy for the pod
:type dnspolicy: str
:param schedulername: Specify a schedulername for the pod
- :type schedulername: str
- :param pod: The fully specified pod.
- :type pod: kubernetes.client.models.V1Pod
+ :type schedulername: Optional[str]
+ :param pod: The fully specified pod. Mutually exclusive with `path_or_string`
+ :type pod: Optional[kubernetes.client.models.V1Pod]
+ :param pod_template_file: Path to YAML file. Mutually exclusive with `pod`
+ :type pod_template_file: Optional[str]
:param extract_xcom: Whether to bring up a container for xcom
:type extract_xcom: bool
"""
@@ -167,8 +183,8 @@ class PodGenerator:
node_selectors=None,
ports=None,
volumes=None,
- image_pull_policy='IfNotPresent',
- restart_policy='Never',
+ image_pull_policy=None,
+ restart_policy=None,
image_pull_secrets=None,
init_containers=None,
service_account_name=None,
@@ -183,9 +199,16 @@ class PodGenerator:
schedulername=None,
priority_class_name=None,
pod=None,
+ pod_template_file=None,
extract_xcom=False,
):
- self.ud_pod = pod
+ self.validate_pod_generator_args(locals())
+
+ if pod_template_file:
+ self.ud_pod = self.deserialize_model_file(pod_template_file)
+ else:
+ self.ud_pod = pod
+
self.pod = k8s.V1Pod()
self.pod.api_version = 'v1'
self.pod.kind = 'Pod'
@@ -348,37 +371,7 @@ class PodGenerator:
'iam.cloud.google.com/service-account': gcp_service_account_key
})
- pod_spec_generator = PodGenerator(
- image=namespaced.get('image'),
- envs=namespaced.get('env'),
- cmds=namespaced.get('cmds'),
- args=namespaced.get('args'),
- labels=namespaced.get('labels'),
- node_selectors=namespaced.get('node_selectors'),
- name=namespaced.get('name'),
- ports=namespaced.get('ports'),
- volumes=namespaced.get('volumes'),
- volume_mounts=namespaced.get('volume_mounts'),
- namespace=namespaced.get('namespace'),
- image_pull_policy=namespaced.get('image_pull_policy'),
- restart_policy=namespaced.get('restart_policy'),
- image_pull_secrets=namespaced.get('image_pull_secrets'),
- init_containers=namespaced.get('init_containers'),
- service_account_name=namespaced.get('service_account_name'),
- resources=resources,
- annotations=namespaced.get('annotations'),
- affinity=namespaced.get('affinity'),
- hostnetwork=namespaced.get('hostnetwork'),
- tolerations=namespaced.get('tolerations'),
- security_context=namespaced.get('security_context'),
- configmaps=namespaced.get('configmaps'),
- dnspolicy=namespaced.get('dnspolicy'),
- schedulername=namespaced.get('schedulername'),
- pod=namespaced.get('pod'),
- extract_xcom=namespaced.get('extract_xcom'),
- )
-
- return pod_spec_generator.gen_pod()
+ return PodGenerator(**namespaced).gen_pod()
@staticmethod
def reconcile_pods(base_pod, client_pod):
@@ -495,12 +488,73 @@ class PodGenerator:
name=pod_id
).gen_pod()
- # Reconcile the pod generated by the Operator and the Pod
- # generated by the .cfg file
- pod_with_executor_config = PodGenerator.reconcile_pods(worker_config,
- kube_executor_config)
- # Reconcile that pod with the dynamic fields.
- return PodGenerator.reconcile_pods(pod_with_executor_config, dynamic_pod)
+ # Reconcile the pods starting with the first chronologically,
+ # Pod from the airflow.cfg -> Pod from executor_config arg -> Pod from the K8s executor
+ pod_list = [worker_config, kube_executor_config, dynamic_pod]
+
+ return reduce(PodGenerator.reconcile_pods, pod_list)
+
+ @staticmethod
+ def deserialize_model_file(path):
+ """
+ :param path: Path to the file
+ :return: a kubernetes.client.models.V1Pod
+
+ Unfortunately we need access to the private method
+ ``_ApiClient__deserialize_model`` from the kubernetes client.
+ This issue is tracked here; https://github.com/kubernetes-client/python/issues/977.
+ """
+ api_client = ApiClient()
+ if os.path.exists(path):
+ with open(path) as stream:
+ pod = yaml.safe_load(stream)
+ else:
+ pod = yaml.safe_load(path)
+
+ # pylint: disable=protected-access
+ return api_client._ApiClient__deserialize_model(pod, k8s.V1Pod)
+
+ @staticmethod
+ def validate_pod_generator_args(given_args):
+ """
+ :param given_args: The arguments passed to the PodGenerator constructor.
+ :type given_args: dict
+ :return: None
+
+ Validate that if `pod` or `pod_template_file` are set that the user is not attempting
+ to configure the pod with the other arguments.
+ """
+ pod_args = list(signature(PodGenerator).parameters.items())
+
+ def predicate(k, v):
+ """
+ :param k: an arg to PodGenerator
+ :type k: string
+ :param v: the parameter of the given arg
+ :type v: inspect.Parameter
+ :return: bool
+
+ returns True if the PodGenerator argument has no default arguments
+ or the default argument is None, and it is not one of the listed field
+ in `non_empty_fields`.
+ """
+ non_empty_fields = {
+ 'pod', 'pod_template_file', 'extract_xcom', 'service_account_name', 'image_pull_policy',
+ 'restart_policy'
+ }
+
+ return (v.default is None or v.default is v.empty) and k not in non_empty_fields
+
+ args_without_defaults = {k: given_args[k] for k, v in pod_args if predicate(k, v) and given_args[k]}
+
+ if given_args['pod'] and given_args['pod_template_file']:
+ raise AirflowConfigException("Cannot pass both `pod` and `pod_template_file` arguments")
+ if args_without_defaults and (given_args['pod'] or given_args['pod_template_file']):
+ raise AirflowConfigException(
+ "Cannot configure pod and pass either `pod` or `pod_template_file`. Fields {} passed.".format(
+ list(args_without_defaults.keys())
+ )
+ )
def merge_objects(base_obj, client_obj):
diff --git a/airflow/kubernetes/worker_configuration.py b/airflow/kubernetes/worker_configuration.py
index 9c35910..0357f9f 100644
--- a/airflow/kubernetes/worker_configuration.py
+++ b/airflow/kubernetes/worker_configuration.py
@@ -28,7 +28,12 @@ from airflow.utils.log.logging_mixin import LoggingMixin
class WorkerConfiguration(LoggingMixin):
- """Contains Kubernetes Airflow Worker configuration logic"""
+ """
+ Contains Kubernetes Airflow Worker configuration logic
+
+ :param kube_config: the kubernetes configuration from airflow.cfg
+ :type kube_config: airflow.executors.kubernetes_executor.KubeConfig
+ """
dags_volume_name = 'airflow-dags'
logs_volume_name = 'airflow-logs'
@@ -424,9 +429,12 @@ class WorkerConfiguration(LoggingMixin):
def as_pod(self):
"""Creates POD."""
- pod_generator = PodGenerator(
+ if self.kube_config.pod_template_file:
+ return PodGenerator(pod_template_file=self.kube_config.pod_template_file).gen_pod()
+
+ pod = PodGenerator(
image=self.kube_config.kube_image,
- image_pull_policy=self.kube_config.kube_image_pull_policy,
+ image_pull_policy=self.kube_config.kube_image_pull_policy or 'IfNotPresent',
image_pull_secrets=self.kube_config.image_pull_secrets,
volumes=self._get_volumes(),
volume_mounts=self._get_volume_mounts(),
@@ -436,10 +444,10 @@ class WorkerConfiguration(LoggingMixin):
tolerations=self.kube_config.kube_tolerations,
envs=self._get_environment(),
node_selectors=self.kube_config.kube_node_selectors,
- service_account_name=self.kube_config.worker_service_account_name,
- )
+ service_account_name=self.kube_config.worker_service_account_name or 'default',
+ restart_policy='Never'
+ ).gen_pod()
- pod = pod_generator.gen_pod()
pod.spec.containers[0].env_from = pod.spec.containers[0].env_from or []
pod.spec.containers[0].env_from.extend(self._get_env_from())
pod.spec.security_context = self._get_security_context()
diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py
index e20324b..b6cecda 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator.py
@@ -827,7 +827,55 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
- @patch("airflow.kubernetes.kube_client.get_kube_client")
+ @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
+ def test_pod_template_file(self, mock_client, monitor_mock, start_mock):
+ from airflow.utils.state import State
+ k = KubernetesPodOperator(
+ task_id='task',
+ pod_template_file='tests/kubernetes/pod.yaml',
+ do_xcom_push=True
+ )
+ monitor_mock.return_value = (State.SUCCESS, None)
+ context = self.create_context(k)
+ k.execute(context)
+ actual_pod = self.api_client.sanitize_for_serialization(k.pod)
+ self.assertEqual({
+ 'apiVersion': 'v1',
+ 'kind': 'Pod',
+ 'metadata': {'name': mock.ANY, 'namespace': 'mem-example'},
+ 'spec': {
+ 'volumes': [{'name': 'xcom', 'emptyDir': {}}],
+ 'containers': [{
+ 'args': ['--vm', '1', '--vm-bytes', '150M', '--vm-hang', '1'],
+ 'command': ['stress'],
+ 'image': 'polinux/stress',
+ 'name': 'memory-demo-ctr',
+ 'resources': {
+ 'limits': {'memory': '200Mi'},
+ 'requests': {'memory': '100Mi'}
+ },
+ 'volumeMounts': [{
+ 'name': 'xcom',
+ 'mountPath': '/airflow/xcom'
+ }]
+ }, {
+ 'name': 'airflow-xcom-sidecar',
+ 'image': "alpine",
+ 'command': ['sh', '-c', PodDefaults.XCOM_CMD],
+ 'volumeMounts': [
+ {
+ 'name': 'xcom',
+ 'mountPath': '/airflow/xcom'
+ }
+ ],
+ 'resources': {'requests': {'cpu': '1m'}},
+ }],
+ }
+ }, actual_pod)
+
+ @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
+ @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
+ @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
def test_pod_priority_class_name(
self,
mock_client,
diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py
index bf7bc56..3dabb78 100644
--- a/tests/executors/test_kubernetes_executor.py
+++ b/tests/executors/test_kubernetes_executor.py
@@ -33,8 +33,8 @@ try:
from airflow.configuration import conf # noqa: F401
from airflow.executors.kubernetes_executor import AirflowKubernetesScheduler, KubeConfig
from airflow.executors.kubernetes_executor import KubernetesExecutor
- from airflow.kubernetes import pod_generator
from airflow.kubernetes.pod_generator import PodGenerator
+ from airflow.kubernetes import pod_generator
from airflow.utils.state import State
except ImportError:
AirflowKubernetesScheduler = None # type: ignore
diff --git a/tests/kubernetes/models/test_pod.py b/tests/kubernetes/models/test_pod.py
index b63af6d..45c32aa 100644
--- a/tests/kubernetes/models/test_pod.py
+++ b/tests/kubernetes/models/test_pod.py
@@ -59,7 +59,6 @@ class TestPod(unittest.TestCase):
'env': [],
'envFrom': [],
'image': 'airflow-worker:latest',
- 'imagePullPolicy': 'IfNotPresent',
'name': 'base',
'ports': [{
'name': 'https',
@@ -72,7 +71,6 @@ class TestPod(unittest.TestCase):
}],
'hostNetwork': False,
'imagePullSecrets': [],
- 'restartPolicy': 'Never',
'volumes': []
}
}, result)
diff --git a/tests/kubernetes/models/test_secret.py b/tests/kubernetes/models/test_secret.py
index 44ab8b3..e91ff68 100644
--- a/tests/kubernetes/models/test_secret.py
+++ b/tests/kubernetes/models/test_secret.py
@@ -100,7 +100,6 @@ class TestSecret(unittest.TestCase):
}],
'envFrom': [{'secretRef': {'name': 'secret_a'}}],
'image': 'airflow-worker:latest',
- 'imagePullPolicy': 'IfNotPresent',
'name': 'base',
'ports': [],
'volumeMounts': [{
@@ -110,7 +109,6 @@ class TestSecret(unittest.TestCase):
}],
'hostNetwork': False,
'imagePullSecrets': [],
- 'restartPolicy': 'Never',
'volumes': [{
'name': 'secretvol' + str(static_uuid),
'secret': {'secretName': 'secret_b'}
diff --git a/tests/kubernetes/pod.yaml b/tests/kubernetes/pod.yaml
new file mode 100644
index 0000000..cd419ed
--- /dev/null
+++ b/tests/kubernetes/pod.yaml
@@ -0,0 +1,33 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+---
+apiVersion: v1
+kind: Pod
+metadata:
+ name: memory-demo
+ namespace: mem-example
+spec:
+ containers:
+ - name: memory-demo-ctr
+ image: polinux/stress
+ resources:
+ limits:
+ memory: "200Mi"
+ requests:
+ memory: "100Mi"
+ command: ["stress"]
+ args: ["--vm", "1", "--vm-bytes", "150M", "--vm-hang", "1"]
diff --git a/tests/kubernetes/test_pod_generator.py b/tests/kubernetes/test_pod_generator.py
index ce15a8b..7d39cdc 100644
--- a/tests/kubernetes/test_pod_generator.py
+++ b/tests/kubernetes/test_pod_generator.py
@@ -21,6 +21,7 @@ import uuid
import kubernetes.client.models as k8s
from kubernetes.client import ApiClient
+from airflow.exceptions import AirflowConfigException
from airflow.kubernetes.k8s_model import append_to_pod
from airflow.kubernetes.pod import Resources
from airflow.kubernetes.pod_generator import PodDefaults, PodGenerator, extend_object_field, merge_objects
@@ -31,6 +32,24 @@ class TestPodGenerator(unittest.TestCase):
def setUp(self):
self.static_uuid = uuid.UUID('cf4a56d2-8101-4217-b027-2af6216feb48')
+ self.deserialize_result = {
+ 'apiVersion': 'v1',
+ 'kind': 'Pod',
+ 'metadata': {'name': 'memory-demo', 'namespace': 'mem-example'},
+ 'spec': {
+ 'containers': [{
+ 'args': ['--vm', '1', '--vm-bytes', '150M', '--vm-hang', '1'],
+ 'command': ['stress'],
+ 'image': 'polinux/stress',
+ 'name': 'memory-demo-ctr',
+ 'resources': {
+ 'limits': {'memory': '200Mi'},
+ 'requests': {'memory': '100Mi'}
+ }
+ }]
+ }
+ }
+
self.envs = {
'ENVIRONMENT': 'prod',
'LOG_LEVEL': 'warning'
@@ -77,7 +96,6 @@ class TestPodGenerator(unittest.TestCase):
'command': [
'sh', '-c', 'echo Hello Kubernetes!'
],
- 'imagePullPolicy': 'IfNotPresent',
'env': [{
'name': 'ENVIRONMENT',
'value': 'prod'
@@ -126,7 +144,6 @@ class TestPodGenerator(unittest.TestCase):
'readOnly': True
}]
}],
- 'restartPolicy': 'Never',
'volumes': [{
'name': 'secretvol' + str(self.static_uuid),
'secret': {
@@ -172,7 +189,7 @@ class TestPodGenerator(unittest.TestCase):
result_dict['spec']['containers'][0]['envFrom'].sort(
key=lambda x: list(x.values())[0]['name']
)
- self.assertDictEqual(result_dict, self.expected)
+ self.assertDictEqual(self.expected, result_dict)
@mock.patch('uuid.uuid4')
def test_gen_pod_extract_xcom(self, mock_uuid):
@@ -238,9 +255,6 @@ class TestPodGenerator(unittest.TestCase):
"name": "example-kubernetes-test-volume",
},
],
- "securityContext": {
- "runAsUser": 1000
- }
}
})
result = self.k8s_client.sanitize_for_serialization(result)
@@ -264,6 +278,7 @@ class TestPodGenerator(unittest.TestCase):
'name': 'example-kubernetes-test-volume'
}],
}],
+ 'hostNetwork': False,
'imagePullSecrets': [],
'volumes': [{
'hostPath': {'path': '/tmp/'},
@@ -339,7 +354,7 @@ class TestPodGenerator(unittest.TestCase):
result = PodGenerator.reconcile_pods(base_pod, mutator_pod)
result = self.k8s_client.sanitize_for_serialization(result)
- self.assertEqual(result, {
+ self.assertEqual({
'apiVersion': 'v1',
'kind': 'Pod',
'metadata': {'name': 'name2-' + self.static_uuid.hex},
@@ -353,7 +368,6 @@ class TestPodGenerator(unittest.TestCase):
],
'envFrom': [],
'image': 'image1',
- 'imagePullPolicy': 'IfNotPresent',
'name': 'base',
'ports': [{
'containerPort': 2118,
@@ -369,7 +383,6 @@ class TestPodGenerator(unittest.TestCase):
}],
'hostNetwork': False,
'imagePullSecrets': [],
- 'restartPolicy': 'Never',
'volumes': [{
'hostPath': {'path': '/tmp/'},
'name': 'example-kubernetes-test-volume1'
@@ -378,7 +391,7 @@ class TestPodGenerator(unittest.TestCase):
'name': 'example-kubernetes-test-volume2'
}]
}
- })
+ }, result)
@mock.patch('uuid.uuid4')
def test_construct_pod_empty_worker_config(self, mock_uuid):
@@ -424,7 +437,6 @@ class TestPodGenerator(unittest.TestCase):
'command': ['command'],
'env': [],
'envFrom': [],
- 'imagePullPolicy': 'IfNotPresent',
'name': 'base',
'ports': [],
'resources': {
@@ -437,7 +449,6 @@ class TestPodGenerator(unittest.TestCase):
}],
'hostNetwork': False,
'imagePullSecrets': [],
- 'restartPolicy': 'Never',
'volumes': []
}
}, sanitized_result)
@@ -486,7 +497,6 @@ class TestPodGenerator(unittest.TestCase):
'command': ['command'],
'env': [],
'envFrom': [],
- 'imagePullPolicy': 'IfNotPresent',
'name': 'base',
'ports': [],
'resources': {
@@ -499,7 +509,6 @@ class TestPodGenerator(unittest.TestCase):
}],
'hostNetwork': False,
'imagePullSecrets': [],
- 'restartPolicy': 'Never',
'volumes': []
}
}, sanitized_result)
@@ -573,7 +582,6 @@ class TestPodGenerator(unittest.TestCase):
'command': ['command'],
'env': [],
'envFrom': [],
- 'imagePullPolicy': 'IfNotPresent',
'name': 'base',
'ports': [],
'resources': {
@@ -587,7 +595,6 @@ class TestPodGenerator(unittest.TestCase):
}],
'hostNetwork': False,
'imagePullSecrets': [],
- 'restartPolicy': 'Never',
'volumes': []
}
}, sanitized_result)
@@ -721,3 +728,44 @@ class TestPodGenerator(unittest.TestCase):
client_spec.containers = [k8s.V1Container(name='client_container1', image='base_image')]
client_spec.active_deadline_seconds = 100
self.assertEqual(client_spec, res)
+
+ def test_deserialize_model_file(self):
+ fixture = 'tests/kubernetes/pod.yaml'
+ result = PodGenerator.deserialize_model_file(fixture)
+ sanitized_res = self.k8s_client.sanitize_for_serialization(result)
+ self.assertEqual(sanitized_res, self.deserialize_result)
+
+ def test_deserialize_model_string(self):
+ fixture = """
+apiVersion: v1
+kind: Pod
+metadata:
+ name: memory-demo
+ namespace: mem-example
+spec:
+ containers:
+ - name: memory-demo-ctr
+ image: polinux/stress
+ resources:
+ limits:
+ memory: "200Mi"
+ requests:
+ memory: "100Mi"
+ command: ["stress"]
+ args: ["--vm", "1", "--vm-bytes", "150M", "--vm-hang", "1"]
+ """
+ result = PodGenerator.deserialize_model_file(fixture)
+ sanitized_res = self.k8s_client.sanitize_for_serialization(result)
+ self.assertEqual(sanitized_res, self.deserialize_result)
+
+ def test_validate_pod_generator(self):
+ with self.assertRaises(AirflowConfigException):
+ PodGenerator(image='k', pod=k8s.V1Pod())
+ with self.assertRaises(AirflowConfigException):
+ PodGenerator(pod=k8s.V1Pod(), pod_template_file='k')
+ with self.assertRaises(AirflowConfigException):
+ PodGenerator(image='k', pod_template_file='k')
+
+ PodGenerator(image='k')
+ PodGenerator(pod_template_file='tests/kubernetes/pod.yaml')
+ PodGenerator(pod=k8s.V1Pod())
diff --git a/tests/kubernetes/test_worker_configuration.py b/tests/kubernetes/test_worker_configuration.py
index 0730595..b6db2b5 100644
--- a/tests/kubernetes/test_worker_configuration.py
+++ b/tests/kubernetes/test_worker_configuration.py
@@ -17,7 +17,6 @@
#
import unittest
-
import six
from tests.compat import mock
@@ -94,6 +93,9 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
self.kube_config.dags_folder = None
self.kube_config.git_dags_folder_mount_point = None
self.kube_config.kube_labels = {'dag_id': 'original_dag_id', 'my_label': 'label_id'}
+ self.kube_config.pod_template_file = ''
+ self.kube_config.restart_policy = ''
+ self.kube_config.image_pull_policy = ''
self.api_client = ApiClient()
@conf_vars({
@@ -358,7 +360,6 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
worker_config.as_pod(),
"default",
"sample-uuid",
-
)
expected_labels = {
'airflow-worker': 'sample-uuid',
@@ -672,6 +673,15 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name='secretref_b'))
], configmaps)
+ def test_pod_template_file(self):
+ fixture = 'tests/kubernetes/pod.yaml'
+ self.kube_config.pod_template_file = fixture
+ worker_config = WorkerConfiguration(self.kube_config)
+ result = worker_config.as_pod()
+ expected = PodGenerator.deserialize_model_file(fixture)
+ expected.metadata.name = mock.ANY
+ self.assertEqual(expected, result)
+
def test_get_labels(self):
worker_config = WorkerConfiguration(self.kube_config)
labels = worker_config._get_labels({'my_kube_executor_label': 'kubernetes'}, {