You are viewing a plain-text version of this content; the canonical link to the original message was not preserved in this copy.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/08/07 10:51:18 UTC
[airflow] branch v1-10-test updated: Fix more PodMutationHook issues for backwards compatibility (#10084)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v1-10-test by this push:
new 21aade4 Fix more PodMutationHook issues for backwards compatibility (#10084)
21aade4 is described below
commit 21aade43ce4182d66dbd05f5b8c000ef9d9740e9
Author: Kaxil Naik <ka...@gmail.com>
AuthorDate: Fri Aug 7 11:50:44 2020 +0100
Fix more PodMutationHook issues for backwards compatibility (#10084)
Co-authored-by: Daniel Imberman <da...@gmail.com>
---
airflow/contrib/executors/kubernetes_executor.py | 20 +
airflow/contrib/kubernetes/pod.py | 137 ++++++-
airflow/executors/kubernetes_executor.py | 6 +
airflow/kubernetes/pod.py | 31 +-
airflow/kubernetes/pod_generator.py | 76 +++-
airflow/kubernetes/pod_launcher.py | 64 +++-
airflow/kubernetes/pod_launcher_helper.py | 96 -----
airflow/kubernetes/volume.py | 17 +-
airflow/operators/python_operator.py | 4 +-
docs/conf.py | 1 +
kubernetes_tests/test_kubernetes_pod_operator.py | 1 -
tests/kubernetes/models/test_pod.py | 98 +++--
tests/kubernetes/models/test_volume.py | 40 ++
tests/kubernetes/test_pod_generator.py | 206 ++++++++++-
tests/kubernetes/test_pod_launcher.py | 137 ++++++-
tests/kubernetes/test_pod_launcher_helper.py | 98 -----
tests/kubernetes/test_worker_configuration.py | 7 +
tests/test_local_settings.py | 269 --------------
tests/test_local_settings/__init__.py | 16 +
tests/test_local_settings/test_local_settings.py | 441 +++++++++++++++++++++++
20 files changed, 1205 insertions(+), 560 deletions(-)
diff --git a/airflow/contrib/executors/kubernetes_executor.py b/airflow/contrib/executors/kubernetes_executor.py
new file mode 100644
index 0000000..416b2d7
--- /dev/null
+++ b/airflow/contrib/executors/kubernetes_executor.py
@@ -0,0 +1,20 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.executors import kubernetes_executor # noqa
diff --git a/airflow/contrib/kubernetes/pod.py b/airflow/contrib/kubernetes/pod.py
index 0ab3616..0ce5800 100644
--- a/airflow/contrib/kubernetes/pod.py
+++ b/airflow/contrib/kubernetes/pod.py
@@ -19,7 +19,18 @@
import warnings
# pylint: disable=unused-import
-from airflow.kubernetes.pod import Port, Resources # noqa
+from typing import List, Union
+
+from kubernetes.client import models as k8s
+
+from airflow.kubernetes.pod import Port, Resources # noqa
+from airflow.kubernetes.volume import Volume
+from airflow.kubernetes.volume_mount import VolumeMount
+from airflow.kubernetes.secret import Secret
+
+from kubernetes.client.api_client import ApiClient
+
+api_client = ApiClient()
warnings.warn(
"This module is deprecated. Please use `airflow.kubernetes.pod`.",
@@ -120,7 +131,7 @@ class Pod(object):
self.affinity = affinity or {}
self.hostnetwork = hostnetwork or False
self.tolerations = tolerations or []
- self.security_context = security_context
+ self.security_context = security_context or {}
self.configmaps = configmaps or []
self.pod_runtime_info_envs = pod_runtime_info_envs or []
self.dnspolicy = dnspolicy
@@ -154,6 +165,7 @@ class Pod(object):
dns_policy=self.dnspolicy,
host_network=self.hostnetwork,
tolerations=self.tolerations,
+ affinity=self.affinity,
security_context=self.security_context,
)
@@ -161,17 +173,18 @@ class Pod(object):
spec=spec,
metadata=meta,
)
- for port in self.ports:
+ for port in _extract_ports(self.ports):
pod = port.attach_to_pod(pod)
- for volume in self.volumes:
+ volumes, _ = _extract_volumes_and_secrets(self.volumes, self.volume_mounts)
+ for volume in volumes:
pod = volume.attach_to_pod(pod)
- for volume_mount in self.volume_mounts:
+ for volume_mount in _extract_volume_mounts(self.volume_mounts):
pod = volume_mount.attach_to_pod(pod)
for secret in self.secrets:
pod = secret.attach_to_pod(pod)
for runtime_info in self.pod_runtime_info_envs:
pod = runtime_info.attach_to_pod(pod)
- pod = self.resources.attach_to_pod(pod)
+ pod = _extract_resources(self.resources).attach_to_pod(pod)
return pod
def as_dict(self):
@@ -182,3 +195,115 @@ class Pod(object):
res['volumes'] = [volume.as_dict() for volume in res['volumes']]
return res
+
+
+def _extract_env_vars_and_secrets(env_vars):
+    """
+    Split a container's env list into plain variables and env-type Secrets.
+
+    :param env_vars: ``k8s.V1EnvVar`` objects and/or camelCase dicts, or None
+    :return: tuple of (``{name: value}`` dict, list of ``Secret``)
+    """
+    result = {}
+    env_vars = env_vars or [] # type: List[Union[k8s.V1EnvVar, dict]]
+    secrets = []
+    for env_var in env_vars:
+        if isinstance(env_var, k8s.V1EnvVar):
+            # secretKeyRef-backed variables become Secret("env", ...) objects
+            secret = _extract_env_secret(env_var)
+            if secret:
+                secrets.append(secret)
+                continue
+            # serialise to a camelCase dict so both input kinds are read alike
+            env_var = api_client.sanitize_for_serialization(env_var)
+        # NOTE(review): only "value" is kept. A non-secret "valueFrom"
+        # (e.g. fieldRef / configMapKeyRef) ends up as None here, and dict
+        # entries are never checked for a secretKeyRef.
+        result[env_var.get("name")] = env_var.get("value")
+    return result, secrets
+
+
+def _extract_env_secret(env_var):
+    """
+    Return an env-type ``Secret`` when ``env_var`` is backed by a secretKeyRef.
+
+    :param env_var: the variable to inspect
+    :type env_var: k8s.V1EnvVar
+    :return: ``Secret("env", ...)`` targeting ``env_var.name``, or None
+    """
+    if env_var.value_from and env_var.value_from.secret_key_ref:
+        secret = env_var.value_from.secret_key_ref # type: k8s.V1SecretKeySelector
+        name = secret.name
+        key = secret.key
+        return Secret("env", deploy_target=env_var.name, secret=name, key=key)
+    return None
+
+
+def _extract_ports(ports):
+    """
+    Normalise a list of ports into airflow ``Port`` objects.
+
+    :param ports: ``k8s.V1ContainerPort`` objects, camelCase dicts and/or
+        ``Port`` objects (passed through unchanged), or None
+    :return: list of ``Port``
+    """
+    result = []
+    ports = ports or [] # type: List[Union[k8s.V1ContainerPort, dict]]
+    for port in ports:
+        if isinstance(port, k8s.V1ContainerPort):
+            # serialise to camelCase so the same keys as the dict branch apply
+            port = api_client.sanitize_for_serialization(port)
+            port = Port(name=port.get("name"), container_port=port.get("containerPort"))
+        elif not isinstance(port, Port):
+            port = Port(name=port.get("name"), container_port=port.get("containerPort"))
+        result.append(port)
+    return result
+
+
+def _extract_resources(resources):
+    """
+    Normalise ``resources`` into an airflow ``Resources`` object.
+
+    :param resources: a ``k8s.V1ResourceRequirements``, a dict with optional
+        ``requests``/``limits`` mappings, an existing ``Resources``, or None
+    :return: a ``Resources`` instance; None for any other type, so the
+        caller's ``attach_to_pod`` still fails with AttributeError as before
+    """
+    if isinstance(resources, Resources):
+        return resources
+    if resources is None:
+        # nothing requested (assumes every Resources() argument defaults to None)
+        return Resources()
+    if isinstance(resources, k8s.V1ResourceRequirements):
+        # either mapping may be unset (None) on the client model
+        requests = resources.requests or {}
+        limits = resources.limits or {}
+    elif isinstance(resources, dict):
+        requests = resources.get('requests') or {}
+        limits = resources.get('limits') or {}
+    else:
+        return None
+    return Resources(
+        request_memory=requests.get('memory'),
+        request_cpu=requests.get('cpu'),
+        request_ephemeral_storage=requests.get('ephemeral-storage'),
+        limit_memory=limits.get('memory'),
+        limit_cpu=limits.get('cpu'),
+        limit_ephemeral_storage=limits.get('ephemeral-storage'),
+        limit_gpu=limits.get('nvidia.com/gpu')
+    )
+
+
+def _extract_security_context(security_context):
+    """
+    Return ``security_context`` as a camelCase dict when it is a
+    ``k8s.V1PodSecurityContext``; any other value is returned unchanged.
+    """
+    if isinstance(security_context, k8s.V1PodSecurityContext):
+        security_context = api_client.sanitize_for_serialization(security_context)
+    return security_context
+
+
+def _extract_volume_mounts(volume_mounts):
+    """
+    Normalise a list of volume mounts into airflow ``VolumeMount`` objects.
+
+    :param volume_mounts: ``k8s.V1VolumeMount`` objects, camelCase dicts
+        (``mountPath``/``subPath``/``readOnly``) and/or ``VolumeMount``
+        objects (passed through unchanged), or None
+    :return: list of ``VolumeMount``
+    """
+    result = []
+    volume_mounts = volume_mounts or [] # type: List[Union[k8s.V1VolumeMount, dict]]
+    for volume_mount in volume_mounts:
+        if isinstance(volume_mount, k8s.V1VolumeMount):
+            # serialise to camelCase, then build exactly like the dict branch
+            volume_mount = api_client.sanitize_for_serialization(volume_mount)
+            volume_mount = VolumeMount(
+                name=volume_mount.get("name"),
+                mount_path=volume_mount.get("mountPath"),
+                sub_path=volume_mount.get("subPath"),
+                read_only=volume_mount.get("readOnly")
+            )
+        elif not isinstance(volume_mount, VolumeMount):
+            volume_mount = VolumeMount(
+                name=volume_mount.get("name"),
+                mount_path=volume_mount.get("mountPath"),
+                sub_path=volume_mount.get("subPath"),
+                read_only=volume_mount.get("readOnly")
+            )
+
+        result.append(volume_mount)
+    return result
+
+
+def _extract_volumes_and_secrets(volumes, volume_mounts):
+    """
+    Normalise pod volumes into airflow ``Volume`` objects, pulling secret-backed
+    ``k8s.V1Volume`` entries out as volume-type ``Secret`` objects.
+
+    :param volumes: ``k8s.V1Volume`` objects, camelCase dicts and/or ``Volume``
+        objects, or None
+    :param volume_mounts: the container's mounts, used to find where each
+        secret volume is mounted
+    :return: tuple of (list of ``Volume``, list of ``Secret``)
+    """
+    result = []
+    volumes = volumes or [] # type: List[Union[k8s.V1Volume, dict]]
+    secrets = []
+    # index mounts by the name of the volume they reference
+    volume_mount_dict = {
+        volume_mount.name: volume_mount
+        for volume_mount in _extract_volume_mounts(volume_mounts)
+    }
+    for volume in volumes:
+        if isinstance(volume, k8s.V1Volume):
+            # the mount is None when the base container does not mount this volume
+            secret = _extract_volume_secret(volume, volume_mount_dict.get(volume.name, None))
+            if secret:
+                secrets.append(secret)
+                continue
+            volume = api_client.sanitize_for_serialization(volume)
+            volume = Volume(name=volume.get("name"), configs=volume)
+        # NOTE(review): dict entries are kept as plain Volumes even when they
+        # carry a "secret" key; only V1Volume objects are turned into Secrets.
+        if not isinstance(volume, Volume):
+            volume = Volume(name=volume.get("name"), configs=volume)
+        result.append(volume)
+    return result, secrets
+
+
+def _extract_volume_secret(volume, volume_mount):
+    """
+    Convert a secret-backed volume into a volume-type ``Secret``.
+
+    :param volume: the pod volume to inspect
+    :type volume: k8s.V1Volume
+    :param volume_mount: the base container's mount of ``volume``, or None if
+        the container does not mount it
+    :return: a ``Secret`` deployed at the mount path, or None when the volume is
+        not secret-backed, is not mounted, or names no secret (the caller then
+        keeps it as an ordinary ``Volume``)
+    """
+    source = volume.secret
+    if not source or volume_mount is None:
+        # without a mount there is no path to deploy the secret to
+        return None
+    # ``secret`` is normally a V1SecretVolumeSource, but a V1Volume built by
+    # setattr from a camelCase config dict carries a plain dict instead.
+    if isinstance(source, dict):
+        secret_name = source.get('secretName') or source.get('secret_name')
+    else:
+        secret_name = source.secret_name
+    if not secret_name:
+        return None
+    # The third argument is the name of the k8s Secret to mount, i.e.
+    # ``secret.secretName`` -- not the (unrelated) volume name.
+    return Secret("volume", volume_mount.mount_path, secret_name, key=None)
diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py
index 7bbdc98..3ad4222 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -417,6 +417,12 @@ class AirflowKubernetesScheduler(LoggingMixin):
kube_executor_config=kube_executor_config,
worker_config=self.worker_configuration_pod
)
+
+ sanitized_pod = self.launcher._client.api_client.sanitize_for_serialization(pod)
+ json_pod = json.dumps(sanitized_pod, indent=2)
+
+ self.log.debug('Pod Creation Request before mutation: \n%s', json_pod)
+
# Reconcile the pod generated by the Operator and the Pod
# generated by the .cfg file
self.log.debug("Kubernetes running for command %s", command)
diff --git a/airflow/kubernetes/pod.py b/airflow/kubernetes/pod.py
index 9e455af..67dc983 100644
--- a/airflow/kubernetes/pod.py
+++ b/airflow/kubernetes/pod.py
@@ -20,7 +20,7 @@ Classes for interacting with Kubernetes API
import copy
-import kubernetes.client.models as k8s
+from kubernetes.client import models as k8s
from airflow.kubernetes.k8s_model import K8SModel
@@ -87,18 +87,25 @@ class Resources(K8SModel):
self.request_ephemeral_storage is not None
def to_k8s_client_obj(self):
+ """Build a ``k8s.V1ResourceRequirements``, omitting entries that are not set."""
- return k8s.V1ResourceRequirements(
- limits={
- 'cpu': self.limit_cpu,
- 'memory': self.limit_memory,
- 'nvidia.com/gpu': self.limit_gpu,
- 'ephemeral-storage': self.limit_ephemeral_storage
- },
- requests={
- 'cpu': self.request_cpu,
- 'memory': self.request_memory,
- 'ephemeral-storage': self.request_ephemeral_storage}
+ limits_raw = {
+ 'cpu': self.limit_cpu,
+ 'memory': self.limit_memory,
+ 'nvidia.com/gpu': self.limit_gpu,
+ 'ephemeral-storage': self.limit_ephemeral_storage
+ }
+ requests_raw = {
+ 'cpu': self.request_cpu,
+ 'memory': self.request_memory,
+ 'ephemeral-storage': self.request_ephemeral_storage
+ }
+
+ # NOTE(review): ``if v`` drops every falsy value, including an int 0 (a
+ # string "0" is kept); compare against None if 0 must be preserved.
+ limits = {k: v for k, v in limits_raw.items() if v}
+ requests = {k: v for k, v in requests_raw.items() if v}
+ resource_req = k8s.V1ResourceRequirements(
+ limits=limits,
+ requests=requests
)
+ return resource_req
def attach_to_pod(self, pod):
cp_pod = copy.deepcopy(pod)
diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py
index d11c175..090e2b1 100644
--- a/airflow/kubernetes/pod_generator.py
+++ b/airflow/kubernetes/pod_generator.py
@@ -36,6 +36,7 @@ from functools import reduce
import kubernetes.client.models as k8s
import yaml
from kubernetes.client.api_client import ApiClient
+from airflow.contrib.kubernetes.pod import _extract_volume_mounts
from airflow.exceptions import AirflowConfigException
from airflow.version import version as airflow_version
@@ -249,7 +250,7 @@ class PodGenerator(object):
self.container.image_pull_policy = image_pull_policy
self.container.ports = ports or []
self.container.resources = resources
- self.container.volume_mounts = volume_mounts or []
+ self.container.volume_mounts = [v.to_k8s_client_obj() for v in _extract_volume_mounts(volume_mounts)]
# Pod Spec
self.spec = k8s.V1PodSpec(containers=[])
@@ -370,6 +371,11 @@ class PodGenerator(object):
requests=requests,
limits=limits
)
+ elif isinstance(resources, dict):
+ resources = k8s.V1ResourceRequirements(
+ requests=resources['requests'],
+ limits=resources['limits']
+ )
annotations = namespaced.get('annotations', {})
gcp_service_account_key = namespaced.get('gcp_service_account_key', None)
@@ -402,13 +408,36 @@ class PodGenerator(object):
client_pod_cp = copy.deepcopy(client_pod)
client_pod_cp.spec = PodGenerator.reconcile_specs(base_pod.spec, client_pod_cp.spec)
-
- client_pod_cp.metadata = merge_objects(base_pod.metadata, client_pod_cp.metadata)
+ client_pod_cp.metadata = PodGenerator.reconcile_metadata(base_pod.metadata, client_pod_cp.metadata)
client_pod_cp = merge_objects(base_pod, client_pod_cp)
return client_pod_cp
@staticmethod
+    def reconcile_metadata(base_meta, client_meta):
+        """
+        Merge pod metadata, letting ``client_meta`` win where both set a value.
+
+        :param base_meta: has the base attributes which are overwritten if they exist
+            in the client_meta and remain if they do not exist in the client_meta
+        :type base_meta: k8s.V1ObjectMeta
+        :param client_meta: the metadata that the client wants to create; it is
+            not modified
+        :type client_meta: k8s.V1ObjectMeta
+        :return: the merged metadata, or None when both inputs are empty
+        """
+        if base_meta and not client_meta:
+            return base_meta
+        if not base_meta and client_meta:
+            return client_meta
+        if not (base_meta and client_meta):
+            return None
+
+        def _client_wins(base, client):
+            # merge_objects() lets base keys overwrite client keys when both are
+            # dicts; for labels/annotations the client must win so that values
+            # it changed (e.g. through a mutation hook) are not discarded.
+            if not base:
+                return client
+            if not client:
+                return base
+            merged = dict(base)
+            merged.update(client)
+            return merged
+
+        # work on a copy so the caller's object is left untouched
+        client_meta = copy.deepcopy(client_meta)
+        client_meta.labels = _client_wins(base_meta.labels, client_meta.labels)
+        client_meta.annotations = _client_wins(base_meta.annotations, client_meta.annotations)
+        # extend_object_field returns a modified *copy*; it must be assigned back
+        client_meta = extend_object_field(base_meta, client_meta, 'owner_references')
+        # Finalizers are plain strings and managedFields entries have no ``name``,
+        # so neither can be keyed by name for extend_object_field. Finalizers are
+        # unioned by value (base first); managed_fields are left to
+        # merge_objects below (the client's value wins, base fills a gap).
+        base_finalizers = base_meta.finalizers or []
+        client_finalizers = client_meta.finalizers or []
+        if base_finalizers and client_finalizers:
+            client_meta.finalizers = base_finalizers + [
+                finalizer for finalizer in client_finalizers
+                if finalizer not in base_finalizers
+            ]
+        return merge_objects(base_meta, client_meta)
+
+ @staticmethod
def reconcile_specs(base_spec,
client_spec):
"""
@@ -580,10 +609,17 @@ def merge_objects(base_obj, client_obj):
client_obj_cp = copy.deepcopy(client_obj)
+ if isinstance(base_obj, dict) and isinstance(client_obj_cp, dict):
+ client_obj_cp.update(base_obj)
+ return client_obj_cp
+
for base_key in base_obj.to_dict().keys():
base_val = getattr(base_obj, base_key, None)
if not getattr(client_obj, base_key, None) and base_val:
- setattr(client_obj_cp, base_key, base_val)
+ if not isinstance(client_obj_cp, dict):
+ setattr(client_obj_cp, base_key, base_val)
+ else:
+ client_obj_cp[base_key] = base_val
return client_obj_cp
@@ -610,6 +646,36 @@ def extend_object_field(base_obj, client_obj, field_name):
setattr(client_obj_cp, field_name, base_obj_field)
return client_obj_cp
- appended_fields = base_obj_field + client_obj_field
+ base_obj_set = _get_dict_from_list(base_obj_field)
+ client_obj_set = _get_dict_from_list(client_obj_field)
+
+ appended_fields = _merge_list_of_objects(base_obj_set, client_obj_set)
+
setattr(client_obj_cp, field_name, appended_fields)
return client_obj_cp
+
+
+def _merge_list_of_objects(base_obj_set, client_obj_set):
+    """
+    Merge two name-keyed maps of list entries into a single list.
+
+    Entries only in ``base_obj_set`` are added; entries in both are combined
+    with ``merge_objects(base, client)``. ``client_obj_set`` is updated in place.
+
+    NOTE(review): the result is ordered by name, not by the original list
+    order. That is harmless for volumes, but confirm it is acceptable for any
+    field whose order can matter (e.g. env vars that reference each other).
+
+    :param base_obj_set: ``{name: entry}`` built from the base object's list
+    :param client_obj_set: ``{name: entry}`` built from the client object's list
+    :return: list of merged entries, sorted by name
+    """
+    for k, v in base_obj_set.items():
+        if k not in client_obj_set:
+            client_obj_set[k] = v
+        else:
+            client_obj_set[k] = merge_objects(v, client_obj_set[k])
+    appended_field_keys = sorted(client_obj_set.keys())
+    appended_fields = [client_obj_set[k] for k in appended_field_keys]
+    return appended_fields
+
+
+def _get_dict_from_list(base_list):
+    """
+    Index a list of k8s entries by their name.
+
+    :param base_list: dicts with a ``'name'`` key and/or k8s model objects
+        (anything with ``to_dict``) that have a ``name`` attribute
+    :type base_list: list
+    :return: ``{name: entry}``; a later entry with a duplicate name replaces
+        an earlier one
+    :raises AirflowConfigException: for entries that are neither a dict nor a
+        model object (e.g. plain strings)
+    :raises AttributeError: for model objects that have no ``name``
+    """
+    result = {}
+    for obj in base_list:
+        if isinstance(obj, dict):
+            result[obj['name']] = obj
+        elif hasattr(obj, "to_dict"):
+            result[obj.name] = obj
+        else:
+            raise AirflowConfigException("Trying to merge invalid object {}".format(obj))
+    return result
diff --git a/airflow/kubernetes/pod_launcher.py b/airflow/kubernetes/pod_launcher.py
index d6507df..39ed836 100644
--- a/airflow/kubernetes/pod_launcher.py
+++ b/airflow/kubernetes/pod_launcher.py
@@ -22,18 +22,21 @@ from datetime import datetime as dt
import tenacity
from kubernetes import watch, client
+from kubernetes.client import models as k8s
from kubernetes.client.rest import ApiException
from kubernetes.stream import stream as kubernetes_stream
from requests.exceptions import BaseHTTPError
from airflow import AirflowException
-from airflow.kubernetes.pod_launcher_helper import convert_to_airflow_pod
-from airflow.kubernetes.pod_generator import PodDefaults
from airflow import settings
+from airflow.contrib.kubernetes.pod import (
+ Pod, _extract_env_vars_and_secrets, _extract_volumes_and_secrets, _extract_volume_mounts,
+ _extract_ports, _extract_security_context
+)
+from airflow.kubernetes.kube_client import get_kube_client
+from airflow.kubernetes.pod_generator import PodDefaults, PodGenerator
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import State
-import kubernetes.client.models as k8s # noqa
-from .kube_client import get_kube_client
class PodStatus:
@@ -90,19 +93,22 @@ class PodLauncher(LoggingMixin):
def _mutate_pod_backcompat(pod):
"""Backwards compatible Pod Mutation Hook"""
+ # Legacy path first: convert to an airflow.contrib Pod, run the hook on it,
+ # convert back and reconcile with the original pod (passed as the base).
+ # NOTE(review): any AttributeError in that path (conversion, hook,
+ # to_v1_kubernetes_pod or reconcile_pods) falls back to running the hook on
+ # the V1Pod, so a hook can run twice. A hook that never touches the pod also
+ # succeeds on the legacy path, which then emits the DeprecationWarning and
+ # round-trips the pod through the conversion.
try:
- settings.pod_mutation_hook(pod)
- # attempts to run pod_mutation_hook using k8s.V1Pod, if this
- # fails we attempt to run by converting pod to Old Pod
- except AttributeError:
+ dummy_pod = _convert_to_airflow_pod(pod)
+ settings.pod_mutation_hook(dummy_pod)
warnings.warn(
"Using `airflow.contrib.kubernetes.pod.Pod` is deprecated. "
"Please use `k8s.V1Pod` instead.", DeprecationWarning, stacklevel=2
)
- dummy_pod = convert_to_airflow_pod(pod)
- settings.pod_mutation_hook(dummy_pod)
dummy_pod = dummy_pod.to_v1_kubernetes_pod()
- return dummy_pod
- return pod
+
+ new_pod = PodGenerator.reconcile_pods(pod, dummy_pod)
+ except AttributeError as e:
+ try:
+ settings.pod_mutation_hook(pod)
+ return pod
+ except AttributeError as e2:
+ # NOTE(review): wrapping both errors in a bare Exception hides the
+ # original exception types from callers.
+ raise Exception([e, e2])
+ return new_pod
def delete_pod(self, pod):
"""Deletes POD"""
@@ -269,7 +275,7 @@ class PodLauncher(LoggingMixin):
return None
def process_status(self, job_id, status):
- """Process status infomration for the JOB"""
+ """Process status information for the JOB"""
status = status.lower()
if status == PodStatus.PENDING:
return State.QUEUED
@@ -284,3 +290,35 @@ class PodLauncher(LoggingMixin):
else:
self.log.info('Event: Invalid state %s on job %s', status, job_id)
return State.FAILED
+
+
+def _convert_to_airflow_pod(pod):
+    """
+    Convert a ``k8s.V1Pod`` into a legacy ``airflow.contrib.kubernetes.pod.Pod``
+    so that an old-style ``pod_mutation_hook`` can operate on it.
+
+    Only the first container is converted. Env vars backed by a secretKeyRef
+    and mounted secret ``V1Volume`` entries are turned into ``Secret`` objects.
+
+    NOTE(review): ``volume_mounts`` below still contains the mounts of volumes
+    that were turned into Secrets, and those Secrets get their own mounts when
+    the legacy pod is converted back, so the same mountPath may appear twice
+    once the result is reconciled with the original pod -- confirm.
+
+    :param pod: the pod to convert
+    :type pod: k8s.V1Pod
+    :return: the equivalent legacy Pod
+    """
+    base_container = pod.spec.containers[0] # type: k8s.V1Container
+    env_vars, secrets = _extract_env_vars_and_secrets(base_container.env)
+    volumes, vol_secrets = _extract_volumes_and_secrets(pod.spec.volumes, base_container.volume_mounts)
+    secrets.extend(vol_secrets)
+    dummy_pod = Pod(
+        image=base_container.image,
+        envs=env_vars,
+        cmds=base_container.command,
+        args=base_container.args,
+        labels=pod.metadata.labels,
+        annotations=pod.metadata.annotations,
+        node_selectors=pod.spec.node_selector,
+        name=pod.metadata.name,
+        ports=_extract_ports(base_container.ports),
+        volumes=volumes,
+        volume_mounts=_extract_volume_mounts(base_container.volume_mounts),
+        namespace=pod.metadata.namespace,
+        image_pull_policy=base_container.image_pull_policy or 'IfNotPresent',
+        tolerations=pod.spec.tolerations,
+        init_containers=pod.spec.init_containers,
+        image_pull_secrets=pod.spec.image_pull_secrets,
+        resources=base_container.resources,
+        service_account_name=pod.spec.service_account_name,
+        secrets=secrets,
+        affinity=pod.spec.affinity,
+        hostnetwork=pod.spec.host_network,
+        security_context=_extract_security_context(pod.spec.security_context)
+    )
+    return dummy_pod
diff --git a/airflow/kubernetes/pod_launcher_helper.py b/airflow/kubernetes/pod_launcher_helper.py
deleted file mode 100644
index 8c9fc6e..0000000
--- a/airflow/kubernetes/pod_launcher_helper.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from typing import List, Union
-
-import kubernetes.client.models as k8s # noqa
-
-from airflow.kubernetes.volume import Volume
-from airflow.kubernetes.volume_mount import VolumeMount
-from airflow.kubernetes.pod import Port
-from airflow.contrib.kubernetes.pod import Pod
-
-
-def convert_to_airflow_pod(pod):
- base_container = pod.spec.containers[0] # type: k8s.V1Container
-
- dummy_pod = Pod(
- image=base_container.image,
- envs=_extract_env_vars(base_container.env),
- volumes=_extract_volumes(pod.spec.volumes),
- volume_mounts=_extract_volume_mounts(base_container.volume_mounts),
- labels=pod.metadata.labels,
- name=pod.metadata.name,
- namespace=pod.metadata.namespace,
- image_pull_policy=base_container.image_pull_policy or 'IfNotPresent',
- cmds=[],
- ports=_extract_ports(base_container.ports)
- )
- return dummy_pod
-
-
-def _extract_env_vars(env_vars):
- """
-
- :param env_vars:
- :type env_vars: list
- :return: result
- :rtype: dict
- """
- result = {}
- env_vars = env_vars or [] # type: List[Union[k8s.V1EnvVar, dict]]
- for env_var in env_vars:
- if isinstance(env_var, k8s.V1EnvVar):
- env_var.to_dict()
- result[env_var.get("name")] = env_var.get("value")
- return result
-
-
-def _extract_volumes(volumes):
- result = []
- volumes = volumes or [] # type: List[Union[k8s.V1Volume, dict]]
- for volume in volumes:
- if isinstance(volume, k8s.V1Volume):
- volume = volume.to_dict()
- result.append(Volume(name=volume.get("name"), configs=volume))
- return result
-
-
-def _extract_volume_mounts(volume_mounts):
- result = []
- volume_mounts = volume_mounts or [] # type: List[Union[k8s.V1VolumeMount, dict]]
- for volume_mount in volume_mounts:
- if isinstance(volume_mount, k8s.V1VolumeMount):
- volume_mount = volume_mount.to_dict()
- result.append(
- VolumeMount(
- name=volume_mount.get("name"),
- mount_path=volume_mount.get("mount_path"),
- sub_path=volume_mount.get("sub_path"),
- read_only=volume_mount.get("read_only"))
- )
-
- return result
-
-
-def _extract_ports(ports):
- result = []
- ports = ports or [] # type: List[Union[k8s.V1ContainerPort, dict]]
- for port in ports:
- if isinstance(port, k8s.V1ContainerPort):
- port = port.to_dict()
- result.append(Port(name=port.get("name"), container_port=port.get("container_port")))
- return result
diff --git a/airflow/kubernetes/volume.py b/airflow/kubernetes/volume.py
index 9d85959..9e5e5c4 100644
--- a/airflow/kubernetes/volume.py
+++ b/airflow/kubernetes/volume.py
@@ -37,9 +37,15 @@ class Volume(K8SModel):
self.configs = configs
def to_k8s_client_obj(self):
+ """
+ Build a ``k8s.V1Volume`` from ``self.name`` and the camelCase ``configs``.
+
+ :raises AttributeError: if a config key has no matching V1Volume attribute
+ """
- configs = self.configs
- configs['name'] = self.name
- return configs
+ from kubernetes.client import models as k8s
+ resp = k8s.V1Volume(name=self.name)
+ for k, v in self.configs.items():
+ snake_key = Volume._convert_to_snake_case(k)
+ if hasattr(resp, snake_key):
+ # only the top-level key is renamed; nested values (e.g. a raw
+ # dict) are assigned as given, not converted to client models
+ setattr(resp, snake_key, v)
+ else:
+ raise AttributeError("V1Volume does not have attribute {}".format(k))
+ return resp
def attach_to_pod(self, pod):
cp_pod = copy.deepcopy(pod)
@@ -47,3 +53,8 @@ class Volume(K8SModel):
cp_pod.spec.volumes = pod.spec.volumes or []
cp_pod.spec.volumes.append(volume)
return cp_pod
+
+    # Map camelCase k8s JSON keys onto python-client attribute names.
+    @staticmethod
+    def _convert_to_snake_case(str):
+        """
+        Convert a camelCase key into the snake_case attribute name used by the
+        kubernetes python client, e.g. ``hostPath`` -> ``host_path``.
+
+        A run of capitals is kept together, so ``downwardAPI`` -> ``downward_api``
+        and ``scaleIO`` -> ``scale_io`` (converting per character gave
+        ``downward_a_p_i``, which is not a V1Volume attribute). Keys that are
+        already snake_case are returned unchanged.
+
+        :param str: the key to convert (name kept for backwards compatibility)
+        :return: the snake_case key
+        """
+        import re
+        # "persistentVolume" -> "persistent_Volume"
+        text = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', str)
+        # "downwardAPI" -> "downward_API"
+        text = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', text)
+        return text.lower().lstrip('_')
diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py
index 78b6a41..392e0fc 100644
--- a/airflow/operators/python_operator.py
+++ b/airflow/operators/python_operator.py
@@ -234,8 +234,8 @@ class PythonVirtualenvOperator(PythonOperator):
python_version=None, # type: Optional[str]
use_dill=False, # type: bool
system_site_packages=True, # type: bool
- op_args=None, # type: Iterable
- op_kwargs=None, # type: Dict
+ op_args=None, # type: Optional[Iterable]
+ op_kwargs=None, # type: Optional[Dict]
provide_context=False, # type: bool
string_args=None, # type: Optional[Iterable[str]]
templates_dict=None, # type: Optional[Dict]
diff --git a/docs/conf.py b/docs/conf.py
index d18b6ea..101d050 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -220,6 +220,7 @@ exclude_patterns = [
'_api/airflow/version',
'_api/airflow/www',
'_api/airflow/www_rbac',
+ '_api/kubernetes_executor',
'_api/main',
'_api/mesos_executor',
'autoapi_templates',
diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py
index b6cecda..50a1258 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator.py
@@ -404,7 +404,6 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
'limits': {
'memory': '64Mi',
'cpu': 0.25,
- 'nvidia.com/gpu': None,
'ephemeral-storage': '2Gi'
}
}
diff --git a/tests/kubernetes/models/test_pod.py b/tests/kubernetes/models/test_pod.py
index 2e53d60..8de33bf 100644
--- a/tests/kubernetes/models/test_pod.py
+++ b/tests/kubernetes/models/test_pod.py
@@ -75,11 +75,16 @@ class TestPod(unittest.TestCase):
}
}, result)
- def test_to_v1_pod(self):
+ @mock.patch('uuid.uuid4')
+ def test_to_v1_pod(self, mock_uuid):
from airflow.contrib.kubernetes.pod import Pod as DeprecatedPod
from airflow.kubernetes.volume import Volume
from airflow.kubernetes.volume_mount import VolumeMount
+ from airflow.kubernetes.secret import Secret
from airflow.kubernetes.pod import Resources
+ import uuid
+ static_uuid = uuid.UUID('cf4a56d2-8101-4217-b027-2af6216feb48')
+ mock_uuid.return_value = static_uuid
pod = DeprecatedPod(
image="foo",
@@ -93,7 +98,14 @@ class TestPod(unittest.TestCase):
request_cpu="100Mi",
limit_gpu="100G"
),
- volumes=[Volume(name="foo", configs={})],
+ volumes=[
+ Volume(name="foo", configs={}),
+ {"name": "bar", 'secret': {'secretName': 'volume-secret'}}
+ ],
+ secrets=[
+ Secret('env', "AIRFLOW_SECRET", 'secret_name', "airflow_config"),
+ Secret("volume", "/opt/airflow", "volume-secret", "secret-key")
+ ],
volume_mounts=[VolumeMount(name="foo", mount_path="/mnt", sub_path="/", read_only=True)]
)
@@ -103,55 +115,35 @@ class TestPod(unittest.TestCase):
result = k8s_client.sanitize_for_serialization(result)
expected = \
- {
- 'metadata':
- {
- 'labels': {},
- 'name': 'bar',
- 'namespace': 'baz'
- },
- 'spec':
- {'containers':
- [
- {
- 'args': [],
- 'command': ['airflow'],
- 'env': [{'name': 'test_key', 'value': 'test_value'}],
- 'image': 'foo',
- 'imagePullPolicy': 'Never',
- 'name': 'base',
- 'volumeMounts':
- [
- {
- 'mountPath': '/mnt',
- 'name': 'foo',
- 'readOnly': True, 'subPath': '/'
- }
- ], # noqa
- 'resources':
- {
- 'limits':
- {
- 'cpu': None,
- 'memory': None,
- 'nvidia.com/gpu': '100G',
- 'ephemeral-storage': None
- },
- 'requests':
- {
- 'cpu': '100Mi',
- 'memory': '1G',
- 'ephemeral-storage': None
- }
- }
- }
- ],
- 'hostNetwork': False,
- 'tolerations': [],
- 'volumes': [
- {'name': 'foo'}
- ]
- }
- }
+ {'metadata': {'labels': {}, 'name': 'bar', 'namespace': 'baz'},
+ 'spec': {'affinity': {},
+ 'containers': [{'args': [],
+ 'command': ['airflow'],
+ 'env': [{'name': 'test_key', 'value': 'test_value'},
+ {'name': 'AIRFLOW_SECRET',
+ 'valueFrom': {'secretKeyRef': {'key': 'airflow_config',
+ 'name': 'secret_name'}}}],
+ 'image': 'foo',
+ 'imagePullPolicy': 'Never',
+ 'name': 'base',
+ 'resources': {'limits': {'nvidia.com/gpu': '100G'},
+ 'requests': {'cpu': '100Mi',
+ 'memory': '1G'}},
+ 'volumeMounts': [{'mountPath': '/mnt',
+ 'name': 'foo',
+ 'readOnly': True,
+ 'subPath': '/'},
+ {'mountPath': '/opt/airflow',
+ 'name': 'secretvol' + str(static_uuid),
+ 'readOnly': True}]}],
+ 'hostNetwork': False,
+ 'securityContext': {},
+ 'tolerations': [],
+ 'volumes': [{'name': 'foo'},
+ {'name': 'bar',
+ 'secret': {'secretName': 'volume-secret'}},
+ {'name': 'secretvol' + str(static_uuid),
+ 'secret': {'secretName': 'volume-secret'}}
+ ]}}
self.maxDiff = None
- self.assertEquals(expected, result)
+ self.assertEqual(expected, result)
diff --git a/tests/kubernetes/models/test_volume.py b/tests/kubernetes/models/test_volume.py
new file mode 100644
index 0000000..c1b8e29
--- /dev/null
+++ b/tests/kubernetes/models/test_volume.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import unittest
+
+from kubernetes.client import models as k8s
+
+from airflow.kubernetes.volume import Volume
+
+
+class TestVolume(unittest.TestCase):
+    """Tests for ``airflow.kubernetes.volume.Volume``."""
+    def test_to_k8s_object(self):
+        """A camelCase config key maps onto the snake_case V1Volume attribute;
+        the nested value is kept as the raw dict."""
+        volume_config = {
+            'persistentVolumeClaim':
+                {
+                    'claimName': 'test-volume'
+                }
+        }
+        volume = Volume(name='test-volume', configs=volume_config)
+        expected_volume = k8s.V1Volume(
+            name="test-volume",
+            persistent_volume_claim={
+                "claimName": "test-volume"
+            }
+        )
+        result = volume.to_k8s_client_obj()
+        self.assertEqual(result, expected_volume)
diff --git a/tests/kubernetes/test_pod_generator.py b/tests/kubernetes/test_pod_generator.py
index d0faf4c..bb714d4 100644
--- a/tests/kubernetes/test_pod_generator.py
+++ b/tests/kubernetes/test_pod_generator.py
@@ -255,6 +255,20 @@ class TestPodGenerator(unittest.TestCase):
"name": "example-kubernetes-test-volume",
},
],
+ "resources": {
+ "requests": {
+ "memory": "256Mi",
+ "cpu": "500m",
+ "ephemeral-storage": "2G",
+ "nvidia.com/gpu": "0"
+ },
+ "limits": {
+ "memory": "512Mi",
+ "cpu": "1000m",
+ "ephemeral-storage": "2G",
+ "nvidia.com/gpu": "0"
+ }
+ }
}
})
result = self.k8s_client.sanitize_for_serialization(result)
@@ -277,6 +291,92 @@ class TestPodGenerator(unittest.TestCase):
'mountPath': '/foo/',
'name': 'example-kubernetes-test-volume'
}],
+ "resources": {
+ "requests": {
+ "memory": "256Mi",
+ "cpu": "500m",
+ "ephemeral-storage": "2G",
+ "nvidia.com/gpu": "0"
+ },
+ "limits": {
+ "memory": "512Mi",
+ "cpu": "1000m",
+ "ephemeral-storage": "2G",
+ "nvidia.com/gpu": "0"
+ }
+ }
+ }],
+ 'hostNetwork': False,
+ 'imagePullSecrets': [],
+ 'volumes': [{
+ 'hostPath': {'path': '/tmp/'},
+ 'name': 'example-kubernetes-test-volume'
+ }],
+ }
+ }, result)
+
+ @mock.patch('uuid.uuid4')
+ def test_from_obj_with_resources_object(self, mock_uuid):
+ mock_uuid.return_value = self.static_uuid
+ result = PodGenerator.from_obj({
+ "KubernetesExecutor": {
+ "annotations": {"test": "annotation"},
+ "volumes": [
+ {
+ "name": "example-kubernetes-test-volume",
+ "hostPath": {"path": "/tmp/"},
+ },
+ ],
+ "volume_mounts": [
+ {
+ "mountPath": "/foo/",
+ "name": "example-kubernetes-test-volume",
+ },
+ ],
+ "resources": {
+ "requests": {
+ "memory": "256Mi",
+ "cpu": "500m",
+ "ephemeral-storage": "2G",
+ "nvidia.com/gpu": "0"
+ },
+ "limits": {
+ "memory": "512Mi",
+ "cpu": "1000m",
+ "ephemeral-storage": "2G",
+ "nvidia.com/gpu": "0"
+ }
+ }
+ }
+ })
+ result = self.k8s_client.sanitize_for_serialization(result)
+
+ self.assertEqual({
+ 'apiVersion': 'v1',
+ 'kind': 'Pod',
+ 'metadata': {
+ 'annotations': {'test': 'annotation'},
+ },
+ 'spec': {
+ 'containers': [{
+ 'args': [],
+ 'command': [],
+ 'env': [],
+ 'envFrom': [],
+ 'name': 'base',
+ 'ports': [],
+ 'volumeMounts': [{
+ 'mountPath': '/foo/',
+ 'name': 'example-kubernetes-test-volume'
+ }],
+ 'resources': {'limits': {'cpu': '1000m',
+ 'ephemeral-storage': '2G',
+ 'memory': '512Mi',
+ 'nvidia.com/gpu': '0'},
+ 'requests': {'cpu': '500m',
+ 'ephemeral-storage': '2G',
+ 'memory': '256Mi',
+ 'nvidia.com/gpu': '0'}},
}],
'hostNetwork': False,
'imagePullSecrets': [],
@@ -586,7 +686,7 @@ class TestPodGenerator(unittest.TestCase):
}, sanitized_result)
@mock.patch('uuid.uuid4')
- def test_construct_pod_empty_execuctor_config(self, mock_uuid):
+ def test_construct_pod_empty_executor_config(self, mock_uuid):
mock_uuid.return_value = self.static_uuid
worker_config = k8s.V1Pod(
spec=k8s.V1PodSpec(
@@ -731,6 +831,92 @@ class TestPodGenerator(unittest.TestCase):
}
}, sanitized_result)
+    @mock.patch('uuid.uuid4')
+    def test_construct_pod_with_mutation(self, mock_uuid):
+        """
+        Tests that construct_pod merges the executor_config over the worker_config:
+        the executor_config's resource limits win, settings present only in the
+        worker_config (security context, metadata annotations) are kept, and the
+        container ends up named 'base' regardless of the names in either config.
+        """
+        mock_uuid.return_value = self.static_uuid
+        worker_config = k8s.V1Pod(
+            metadata=k8s.V1ObjectMeta(
+                name='gets-overridden-by-dynamic-args',
+                annotations={
+                    'should': 'stay'
+                }
+            ),
+            spec=k8s.V1PodSpec(
+                containers=[
+                    k8s.V1Container(
+                        name='doesnt-override',
+                        resources=k8s.V1ResourceRequirements(
+                            limits={
+                                'cpu': '1m',
+                                'memory': '1G'
+                            }
+                        ),
+                        security_context=k8s.V1SecurityContext(
+                            run_as_user=1
+                        )
+                    )
+                ]
+            )
+        )
+        executor_config = k8s.V1Pod(
+            spec=k8s.V1PodSpec(
+                containers=[
+                    k8s.V1Container(
+                        name='doesnt-override-either',
+                        resources=k8s.V1ResourceRequirements(
+                            limits={
+                                'cpu': '2m',
+                                'memory': '2G'
+                            }
+                        )
+                    )
+                ]
+            )
+        )
+
+        result = PodGenerator.construct_pod(
+            'dag_id',
+            'task_id',
+            'pod_id',
+            3,
+            'date',
+            ['command'],
+            executor_config,
+            worker_config,
+            'namespace',
+            'uuid',
+        )
+        sanitized_result = self.k8s_client.sanitize_for_serialization(result)
+
+        # Build a copy instead of calling self.metadata.update(): mutating the
+        # shared fixture would leak the annotation into anything else using it.
+        expected_metadata = dict(self.metadata, annotations={'should': 'stay'})
+
+        self.assertEqual({
+            'apiVersion': 'v1',
+            'kind': 'Pod',
+            'metadata': expected_metadata,
+            'spec': {
+                'containers': [{
+                    'args': [],
+                    'command': ['command'],
+                    'env': [],
+                    'envFrom': [],
+                    'name': 'base',
+                    'ports': [],
+                    'resources': {
+                        'limits': {
+                            'cpu': '2m',
+                            'memory': '2G'
+                        }
+                    },
+                    'volumeMounts': [],
+                    'securityContext': {'runAsUser': 1}
+                }],
+                'hostNetwork': False,
+                'imagePullSecrets': [],
+                'volumes': []
+            }
+        }, sanitized_result)
+
def test_merge_objects_empty(self):
annotations = {'foo1': 'bar1'}
base_obj = k8s.V1ObjectMeta(annotations=annotations)
@@ -901,3 +1087,21 @@ spec:
PodGenerator(image='k')
PodGenerator(pod_template_file='tests/kubernetes/pod.yaml')
PodGenerator(pod=k8s.V1Pod())
+
+    def test_add_custom_label(self):
+        """Labels and annotations supplied via worker_config survive construct_pod."""
+        # ``k8s`` is already imported at module level (used by the other tests).
+        pod = PodGenerator.construct_pod(
+            namespace="test",
+            worker_uuid="test",
+            pod_id="test",
+            dag_id="test",
+            task_id="test",
+            try_number=1,
+            date="23-07-2020",
+            command=["test"],
+            kube_executor_config=None,
+            worker_config=k8s.V1Pod(metadata=k8s.V1ObjectMeta(labels={"airflow-test": "airflow-task-pod"},
+                                                              annotations={"my.annotation": "foo"})))
+        self.assertEqual("airflow-task-pod", pod.metadata.labels["airflow-test"])
+        self.assertEqual("foo", pod.metadata.annotations["my.annotation"])
diff --git a/tests/kubernetes/test_pod_launcher.py b/tests/kubernetes/test_pod_launcher.py
index 09ba339..c86a6cf 100644
--- a/tests/kubernetes/test_pod_launcher.py
+++ b/tests/kubernetes/test_pod_launcher.py
@@ -16,11 +16,17 @@
# under the License.
import unittest
import mock
+from kubernetes.client import models as k8s
from requests.exceptions import BaseHTTPError
from airflow import AirflowException
-from airflow.kubernetes.pod_launcher import PodLauncher
+from airflow.contrib.kubernetes.pod import Pod
+from airflow.kubernetes.pod import Port
+from airflow.kubernetes.pod_launcher import PodLauncher, _convert_to_airflow_pod
+from airflow.kubernetes.volume import Volume
+from airflow.kubernetes.secret import Secret
+from airflow.kubernetes.volume_mount import VolumeMount
class TestPodLauncher(unittest.TestCase):
@@ -162,3 +168,132 @@ class TestPodLauncher(unittest.TestCase):
self.pod_launcher.read_pod,
mock.sentinel
)
+
+
+class TestPodLauncherHelper(unittest.TestCase):
+    """Tests for converting a ``k8s.V1Pod`` back into a legacy airflow ``Pod``."""
+
+    def test_convert_to_airflow_pod(self):
+        # The compared dicts are large; show the full diff on failure.
+        self.maxDiff = None
+        input_pod = k8s.V1Pod(
+            metadata=k8s.V1ObjectMeta(
+                name="foo",
+                namespace="bar"
+            ),
+            spec=k8s.V1PodSpec(
+                containers=[
+                    k8s.V1Container(
+                        name="base",
+                        command=["foo"],
+                        image="myimage",
+                        env=[
+                            # Env var sourced from a secret key; expected to come
+                            # back as an "env" Secret rather than a plain env var.
+                            k8s.V1EnvVar(
+                                name="AIRFLOW_SECRET",
+                                value_from=k8s.V1EnvVarSource(
+                                    secret_key_ref=k8s.V1SecretKeySelector(
+                                        name="ai",
+                                        key="secret_key"
+                                    )
+                                ))
+                        ],
+                        ports=[
+                            k8s.V1ContainerPort(
+                                name="myport",
+                                container_port=8080,
+                            )
+                        ],
+                        volume_mounts=[
+                            k8s.V1VolumeMount(
+                                name="myvolume",
+                                mount_path="/tmp/mount",
+                                # Kubernetes readOnly is a boolean, not a string.
+                                read_only=True
+                            ),
+                            k8s.V1VolumeMount(
+                                name='airflow-config',
+                                mount_path='/config',
+                                sub_path='airflow.cfg',
+                                read_only=True
+                            ),
+                            k8s.V1VolumeMount(
+                                name="airflow-secret",
+                                mount_path="/opt/mount",
+                                read_only=True
+                            )]
+                    )
+                ],
+                security_context=k8s.V1PodSecurityContext(
+                    run_as_user=0,
+                    fs_group=0,
+                ),
+                volumes=[
+                    k8s.V1Volume(
+                        name="myvolume"
+                    ),
+                    k8s.V1Volume(
+                        name="airflow-config",
+                        # NOTE(review): V1Volume.config_map normally takes a
+                        # V1ConfigMapVolumeSource; V1ConfigMap is only used here as a
+                        # serializable payload for the conversion.
+                        config_map=k8s.V1ConfigMap(
+                            data="airflow-data"
+                        )
+                    ),
+                    # Secret-backed volume; expected to come back as a "volume"
+                    # Secret and to be absent from the converted volumes.
+                    k8s.V1Volume(
+                        name="airflow-secret",
+                        secret=k8s.V1SecretVolumeSource(
+                            secret_name="secret-name",
+                        )
+                    )
+                ]
+            )
+        )
+        result_pod = _convert_to_airflow_pod(input_pod)
+
+        expected = Pod(
+            name="foo",
+            namespace="bar",
+            envs={},
+            cmds=["foo"],
+            image="myimage",
+            ports=[
+                Port(name="myport", container_port=8080)
+            ],
+            volume_mounts=[
+                VolumeMount(
+                    name="myvolume",
+                    mount_path="/tmp/mount",
+                    sub_path=None,
+                    read_only=True
+                ),
+                VolumeMount(
+                    name="airflow-config",
+                    read_only=True,
+                    mount_path="/config",
+                    sub_path="airflow.cfg"
+                ),
+                VolumeMount(
+                    name="airflow-secret",
+                    read_only=True,
+                    mount_path="/opt/mount",
+                    sub_path=None,
+                )],
+            secrets=[Secret("env", "AIRFLOW_SECRET", "ai", "secret_key"),
+                     Secret('volume', '/opt/mount', 'airflow-secret', "secret-name")],
+            security_context={'fsGroup': 0, 'runAsUser': 0},
+            volumes=[Volume(name="myvolume", configs={'name': 'myvolume'}),
+                     Volume(name="airflow-config", configs={'configMap': {'data': 'airflow-data'},
+                                                            'name': 'airflow-config'})]
+        )
+        expected_dict = expected.as_dict()
+        result_dict = result_pod.as_dict()
+        result_dict['volumes'] = self.pull_out_volumes(result_dict)
+        self.assertDictEqual(expected_dict, result_dict)
+
+    @staticmethod
+    def pull_out_volumes(result_dict):
+        """Reduce each volume in ``result_dict['volumes']`` to its name plus non-empty public configs.
+
+        Config entries with falsy values or keys starting with ``_`` are dropped so
+        the converted volumes can be compared with hand-written expectations.
+        """
+        return [
+            {
+                'name': volume['name'],
+                'configs': {
+                    key: value
+                    for key, value in volume['configs'].items()
+                    if value and not key.startswith('_')
+                },
+            }
+            for volume in result_dict['volumes']
+        ]
diff --git a/tests/kubernetes/test_pod_launcher_helper.py b/tests/kubernetes/test_pod_launcher_helper.py
deleted file mode 100644
index 761d138..0000000
--- a/tests/kubernetes/test_pod_launcher_helper.py
+++ /dev/null
@@ -1,98 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import unittest
-
-from airflow.kubernetes.pod import Port
-from airflow.kubernetes.volume_mount import VolumeMount
-from airflow.kubernetes.volume import Volume
-from airflow.kubernetes.pod_launcher_helper import convert_to_airflow_pod
-from airflow.contrib.kubernetes.pod import Pod
-import kubernetes.client.models as k8s
-
-
-class TestPodLauncherHelper(unittest.TestCase):
- def test_convert_to_airflow_pod(self):
- input_pod = k8s.V1Pod(
- metadata=k8s.V1ObjectMeta(
- name="foo",
- namespace="bar"
- ),
- spec=k8s.V1PodSpec(
- containers=[
- k8s.V1Container(
- name="base",
- command="foo",
- image="myimage",
- ports=[
- k8s.V1ContainerPort(
- name="myport",
- container_port=8080,
- )
- ],
- volume_mounts=[k8s.V1VolumeMount(
- name="mymount",
- mount_path="/tmp/mount",
- read_only="True"
- )]
- )
- ],
- volumes=[
- k8s.V1Volume(
- name="myvolume"
- )
- ]
- )
- )
- result_pod = convert_to_airflow_pod(input_pod)
-
- expected = Pod(
- name="foo",
- namespace="bar",
- envs={},
- cmds=[],
- image="myimage",
- ports=[
- Port(name="myport", container_port=8080)
- ],
- volume_mounts=[VolumeMount(
- name="mymount",
- mount_path="/tmp/mount",
- sub_path=None,
- read_only="True"
- )],
- volumes=[Volume(name="myvolume", configs={'name': 'myvolume'})]
- )
- expected_dict = expected.as_dict()
- result_dict = result_pod.as_dict()
- parsed_configs = self.pull_out_volumes(result_dict)
- result_dict['volumes'] = parsed_configs
- self.maxDiff = None
-
- self.assertDictEqual(expected_dict, result_dict)
-
- @staticmethod
- def pull_out_volumes(result_dict):
- parsed_configs = []
- for volume in result_dict['volumes']:
- vol = {'name': volume['name']}
- confs = {}
- for k, v in volume['configs'].items():
- if v and k[0] != '_':
- confs[k] = v
- vol['configs'] = confs
- parsed_configs.append(vol)
- return parsed_configs
diff --git a/tests/kubernetes/test_worker_configuration.py b/tests/kubernetes/test_worker_configuration.py
index a94a112..0273ae8 100644
--- a/tests/kubernetes/test_worker_configuration.py
+++ b/tests/kubernetes/test_worker_configuration.py
@@ -173,6 +173,13 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
self.assertNotIn('AIRFLOW__CORE__DAGS_FOLDER', env)
+    @conf_vars({
+        ('kubernetes', 'airflow_configmap'): 'airflow-configmap'})
+    def test_worker_adds_config(self):
+        """The configured airflow configmap is mounted into workers as a volume."""
+        # self.kube_config is a test fixture, so conf_vars alone may not reach it;
+        # set the value directly, as the neighbouring tests do.
+        self.kube_config.airflow_configmap = 'airflow-configmap'
+        worker_config = WorkerConfiguration(self.kube_config)
+        volumes = worker_config._get_volumes()
+
+        # NOTE(review): assumes _get_volumes() returns k8s.V1Volume objects and names
+        # the configmap volume 'airflow-config' -- confirm against WorkerConfiguration.
+        config_volumes = [volume for volume in volumes if volume.name == 'airflow-config']
+        self.assertEqual(1, len(config_volumes))
+        self.assertEqual('airflow-configmap', config_volumes[0].config_map.name)
+
def test_worker_environment_when_dags_folder_specified(self):
self.kube_config.airflow_configmap = 'airflow-configmap'
self.kube_config.git_dags_folder_mount_point = ''
diff --git a/tests/test_local_settings.py b/tests/test_local_settings.py
deleted file mode 100644
index 0e45ad8..0000000
--- a/tests/test_local_settings.py
+++ /dev/null
@@ -1,269 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-import os
-import sys
-import tempfile
-import unittest
-from airflow.kubernetes import pod_generator
-from tests.compat import MagicMock, Mock, call, patch
-
-
-SETTINGS_FILE_POLICY = """
-def test_policy(task_instance):
- task_instance.run_as_user = "myself"
-"""
-
-SETTINGS_FILE_POLICY_WITH_DUNDER_ALL = """
-__all__ = ["test_policy"]
-
-def test_policy(task_instance):
- task_instance.run_as_user = "myself"
-
-def not_policy():
- print("This shouldn't be imported")
-"""
-
-SETTINGS_FILE_POD_MUTATION_HOOK = """
-from airflow.kubernetes.volume import Volume
-from airflow.kubernetes.pod import Port, Resources
-
-def pod_mutation_hook(pod):
- pod.namespace = 'airflow-tests'
- pod.image = 'my_image'
- pod.volumes.append(Volume(name="bar", configs={}))
- pod.ports = [Port(container_port=8080)]
- pod.resources = Resources(
- request_memory="2G",
- request_cpu="200Mi",
- limit_gpu="200G"
- )
-
-"""
-
-SETTINGS_FILE_POD_MUTATION_HOOK_V1_POD = """
-def pod_mutation_hook(pod):
- pod.spec.containers[0].image = "test-image"
-
-"""
-
-
-class SettingsContext:
- def __init__(self, content, module_name):
- self.content = content
- self.settings_root = tempfile.mkdtemp()
- filename = "{}.py".format(module_name)
- self.settings_file = os.path.join(self.settings_root, filename)
-
- def __enter__(self):
- with open(self.settings_file, 'w') as handle:
- handle.writelines(self.content)
- sys.path.append(self.settings_root)
- return self.settings_file
-
- def __exit__(self, *exc_info):
- sys.path.remove(self.settings_root)
-
-
-class LocalSettingsTest(unittest.TestCase):
- # Make sure that the configure_logging is not cached
- def setUp(self):
- self.old_modules = dict(sys.modules)
-
- def tearDown(self):
- # Remove any new modules imported during the test run. This lets us
- # import the same source files for more than one test.
- for mod in [m for m in sys.modules if m not in self.old_modules]:
- del sys.modules[mod]
-
- @patch("airflow.settings.import_local_settings")
- @patch("airflow.settings.prepare_syspath")
- def test_initialize_order(self, prepare_syspath, import_local_settings):
- """
- Tests that import_local_settings is called after prepare_classpath
- """
- mock = Mock()
- mock.attach_mock(prepare_syspath, "prepare_syspath")
- mock.attach_mock(import_local_settings, "import_local_settings")
-
- import airflow.settings
- airflow.settings.initialize()
-
- mock.assert_has_calls([call.prepare_syspath(), call.import_local_settings()])
-
- def test_import_with_dunder_all_not_specified(self):
- """
- Tests that if __all__ is specified in airflow_local_settings,
- only module attributes specified within are imported.
- """
- with SettingsContext(SETTINGS_FILE_POLICY_WITH_DUNDER_ALL, "airflow_local_settings"):
- from airflow import settings
- settings.import_local_settings() # pylint: ignore
-
- with self.assertRaises(AttributeError):
- settings.not_policy()
-
- def test_import_with_dunder_all(self):
- """
- Tests that if __all__ is specified in airflow_local_settings,
- only module attributes specified within are imported.
- """
- with SettingsContext(SETTINGS_FILE_POLICY_WITH_DUNDER_ALL, "airflow_local_settings"):
- from airflow import settings
- settings.import_local_settings() # pylint: ignore
-
- task_instance = MagicMock()
- settings.test_policy(task_instance)
-
- assert task_instance.run_as_user == "myself"
-
- @patch("airflow.settings.log.debug")
- def test_import_local_settings_without_syspath(self, log_mock):
- """
- Tests that an ImportError is raised in import_local_settings
- if there is no airflow_local_settings module on the syspath.
- """
- from airflow import settings
- settings.import_local_settings()
- log_mock.assert_called_with("Failed to import airflow_local_settings.", exc_info=True)
-
- def test_policy_function(self):
- """
- Tests that task instances are mutated by the policy
- function in airflow_local_settings.
- """
- with SettingsContext(SETTINGS_FILE_POLICY, "airflow_local_settings"):
- from airflow import settings
- settings.import_local_settings() # pylint: ignore
-
- task_instance = MagicMock()
- settings.test_policy(task_instance)
-
- assert task_instance.run_as_user == "myself"
-
- def test_pod_mutation_hook(self):
- """
- Tests that pods are mutated by the pod_mutation_hook
- function in airflow_local_settings.
- """
- with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK, "airflow_local_settings"):
- from airflow import settings
- settings.import_local_settings() # pylint: ignore
-
- pod = MagicMock()
- pod.volumes = []
- settings.pod_mutation_hook(pod)
-
- assert pod.namespace == 'airflow-tests'
- self.assertEqual(pod.volumes[0].name, "bar")
-
- def test_pod_mutation_to_k8s_pod(self):
- with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK, "airflow_local_settings"):
- from airflow import settings
- settings.import_local_settings() # pylint: ignore
- from airflow.kubernetes.pod_launcher import PodLauncher
-
- self.mock_kube_client = Mock()
- self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client)
- pod = pod_generator.PodGenerator(
- image="foo",
- name="bar",
- namespace="baz",
- image_pull_policy="Never",
- cmds=["foo"],
- volume_mounts=[
- {"name": "foo", "mount_path": "/mnt", "sub_path": "/", "read_only": "True"}
- ],
- volumes=[{"name": "foo"}]
- ).gen_pod()
-
- self.assertEqual(pod.metadata.namespace, "baz")
- self.assertEqual(pod.spec.containers[0].image, "foo")
- self.assertEqual(pod.spec.volumes, [{'name': 'foo'}])
- self.assertEqual(pod.spec.containers[0].ports, [])
- self.assertEqual(pod.spec.containers[0].resources, None)
-
- pod = self.pod_launcher._mutate_pod_backcompat(pod)
-
- self.assertEqual(pod.metadata.namespace, "airflow-tests")
- self.assertEqual(pod.spec.containers[0].image, "my_image")
- self.assertEqual(pod.spec.volumes, [{'name': 'foo'}, {'name': 'bar'}])
- self.maxDiff = None
- self.assertEqual(
- pod.spec.containers[0].ports[0].to_dict(),
- {
- "container_port": 8080,
- "host_ip": None,
- "host_port": None,
- "name": None,
- "protocol": None
- }
- )
- self.assertEqual(
- pod.spec.containers[0].resources.to_dict(),
- {
- 'limits': {
- 'cpu': None,
- 'memory': None,
- 'ephemeral-storage': None,
- 'nvidia.com/gpu': '200G'},
- 'requests': {'cpu': '200Mi', 'ephemeral-storage': None, 'memory': '2G'}
- }
- )
-
- def test_pod_mutation_v1_pod(self):
- with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK_V1_POD, "airflow_local_settings"):
- from airflow import settings
- settings.import_local_settings() # pylint: ignore
- from airflow.kubernetes.pod_launcher import PodLauncher
-
- self.mock_kube_client = Mock()
- self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client)
- pod = pod_generator.PodGenerator(
- image="myimage",
- cmds=["foo"],
- volume_mounts={
- "name": "foo", "mount_path": "/mnt", "sub_path": "/", "read_only": "True"
- },
- volumes=[{"name": "foo"}]
- ).gen_pod()
-
- self.assertEqual(pod.spec.containers[0].image, "myimage")
- pod = self.pod_launcher._mutate_pod_backcompat(pod)
- self.assertEqual(pod.spec.containers[0].image, "test-image")
-
-
-class TestStatsWithAllowList(unittest.TestCase):
-
- def setUp(self):
- from airflow.settings import SafeStatsdLogger, AllowListValidator
- self.statsd_client = Mock()
- self.stats = SafeStatsdLogger(self.statsd_client, AllowListValidator("stats_one, stats_two"))
-
- def test_increment_counter_with_allowed_key(self):
- self.stats.incr('stats_one')
- self.statsd_client.incr.assert_called_once_with('stats_one', 1, 1)
-
- def test_increment_counter_with_allowed_prefix(self):
- self.stats.incr('stats_two.bla')
- self.statsd_client.incr.assert_called_once_with('stats_two.bla', 1, 1)
-
- def test_not_increment_counter_if_not_allowed(self):
- self.stats.incr('stats_three')
- self.statsd_client.assert_not_called()
diff --git a/tests/test_local_settings/__init__.py b/tests/test_local_settings/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/tests/test_local_settings/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/test_local_settings/test_local_settings.py b/tests/test_local_settings/test_local_settings.py
new file mode 100644
index 0000000..ece813d
--- /dev/null
+++ b/tests/test_local_settings/test_local_settings.py
@@ -0,0 +1,441 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import os
+import sys
+import tempfile
+import unittest
+from airflow.kubernetes import pod_generator
+from kubernetes.client import ApiClient
+import kubernetes.client.models as k8s
+from tests.compat import MagicMock, Mock, mock, call, patch
+
+# Shared client, used only to serialize V1Pod objects into plain dicts for comparison.
+api_client = ApiClient()
+
+# The SETTINGS_FILE_* strings below are the source of generated
+# ``airflow_local_settings`` modules; tests write them to disk via SettingsContext.
+
+# Defines a single policy that sets run_as_user on the task instance.
+SETTINGS_FILE_POLICY = """
+def test_policy(task_instance):
+    task_instance.run_as_user = "myself"
+"""
+
+# Same policy, but restricted by __all__ so that not_policy must not be imported.
+SETTINGS_FILE_POLICY_WITH_DUNDER_ALL = """
+__all__ = ["test_policy"]
+
+def test_policy(task_instance):
+    task_instance.run_as_user = "myself"
+
+def not_policy():
+    print("This shouldn't be imported")
+"""
+
+# Hook written against the legacy airflow ``Pod`` attribute API (namespace, image,
+# volumes, ports, resources, labels, envs, tolerations, affinity, security_context,
+# args); it deliberately mixes airflow objects (Volume, Port, Resources) with raw dicts.
+SETTINGS_FILE_POD_MUTATION_HOOK = """
+from airflow.kubernetes.volume import Volume
+from airflow.kubernetes.pod import Port, Resources
+
+def pod_mutation_hook(pod):
+    pod.namespace = 'airflow-tests'
+    pod.image = 'my_image'
+    pod.volumes.append(Volume(name="bar", configs={}))
+    pod.ports = [Port(container_port=8080), {"containerPort": 8081}]
+    pod.resources = Resources(
+        request_memory="2G",
+        request_cpu="200Mi",
+        limit_gpu="200G"
+    )
+
+    secret_volume = {
+        "name":  "airflow-secrets-mount",
+        "secret": {
+          "secretName": "airflow-test-secrets"
+        }
+    }
+    secret_volume_mount = {
+      "name": "airflow-secrets-mount",
+      "readOnly": True,
+      "mountPath": "/opt/airflow/secrets/"
+    }
+
+    pod.volumes.append(secret_volume)
+    pod.volume_mounts.append(secret_volume_mount)
+
+    pod.labels.update({"test_label": "test_value"})
+    pod.envs.update({"TEST_USER": "ADMIN"})
+
+    pod.tolerations += [
+        {"key": "dynamic-pods", "operator": "Equal", "value": "true", "effect": "NoSchedule"}
+    ]
+    pod.affinity.update(
+        {"nodeAffinity":
+            {"requiredDuringSchedulingIgnoredDuringExecution":
+                {"nodeSelectorTerms":
+                    [{
+                        "matchExpressions": [
+                            {"key": "test/dynamic-pods", "operator": "In", "values": ["true"]}
+                        ]
+                    }]
+                }
+            }
+        }
+    )
+
+    if 'fsGroup' in pod.security_context and pod.security_context['fsGroup'] == 0 :
+        del pod.security_context['fsGroup']
+    if 'runAsUser' in pod.security_context and pod.security_context['runAsUser'] == 0 :
+        del pod.security_context['runAsUser']
+
+    if pod.args and pod.args[0] == "/bin/sh":
+        pod.args = ['/bin/sh', '-c', 'touch /tmp/healthy2']
+
+"""
+
+# Hook written against ``kubernetes.client.models.V1Pod``: it edits pod.spec and
+# pod.metadata directly and mixes raw dicts with V1ContainerPort objects.
+SETTINGS_FILE_POD_MUTATION_HOOK_V1_POD = """
+def pod_mutation_hook(pod):
+    from kubernetes.client import models as k8s
+    secret_volume = {
+        "name":  "airflow-secrets-mount",
+        "secret": {
+          "secretName": "airflow-test-secrets"
+        }
+    }
+    secret_volume_mount = {
+      "name": "airflow-secrets-mount",
+      "readOnly": True,
+      "mountPath": "/opt/airflow/secrets/"
+    }
+    base_container = pod.spec.containers[0]
+    base_container.image = "test-image"
+    base_container.volume_mounts.append(secret_volume_mount)
+    base_container.env.extend([{'name': 'TEST_USER', 'value': 'ADMIN'}])
+    base_container.ports.extend([{'containerPort': 8080}, k8s.V1ContainerPort(container_port=8081)])
+
+    pod.spec.volumes.append(secret_volume)
+    pod.metadata.namespace = 'airflow-tests'
+
+"""
+
+
+class SettingsContext:
+    """Context manager exposing ``content`` as an importable module on ``sys.path``.
+
+    On entry it writes ``content`` to ``<tmpdir>/<module_name>.py`` and appends the
+    temporary directory to ``sys.path``; on exit it removes that directory from
+    ``sys.path`` and deletes it, so each use does not leak a temp directory.
+    Returns the path of the written module file from ``__enter__``.
+    """
+
+    def __init__(self, content, module_name):
+        self.content = content
+        self.settings_root = tempfile.mkdtemp()
+        filename = "{}.py".format(module_name)
+        self.settings_file = os.path.join(self.settings_root, filename)
+
+    def __enter__(self):
+        with open(self.settings_file, 'w') as handle:
+            handle.write(self.content)
+        sys.path.append(self.settings_root)
+        return self.settings_file
+
+    def __exit__(self, *exc_info):
+        import shutil
+
+        sys.path.remove(self.settings_root)
+        # Anything the test needed was imported inside the ``with`` block, so the
+        # directory created by mkdtemp() can now be removed.
+        shutil.rmtree(self.settings_root, ignore_errors=True)
+
+
+class LocalSettingsTest(unittest.TestCase):
+    """Tests for importing ``airflow_local_settings`` and applying its policies and pod hooks."""
+
+    def setUp(self):
+        # Snapshot loaded modules so tearDown can unload whatever a test imports
+        # (notably the generated ``airflow_local_settings`` module).
+        self.old_modules = dict(sys.modules)
+        self.maxDiff = None
+
+    def tearDown(self):
+        # Remove any new modules imported during the test run. This lets us
+        # import the same source files for more than one test.
+        for mod in [m for m in sys.modules if m not in self.old_modules]:
+            del sys.modules[mod]
+
+    @patch("airflow.settings.import_local_settings")
+    @patch("airflow.settings.prepare_syspath")
+    def test_initialize_order(self, prepare_syspath, import_local_settings):
+        """
+        Tests that import_local_settings is called after prepare_syspath
+        """
+        # Named ``manager`` so it does not shadow the module-level ``mock`` import.
+        manager = Mock()
+        manager.attach_mock(prepare_syspath, "prepare_syspath")
+        manager.attach_mock(import_local_settings, "import_local_settings")
+
+        import airflow.settings
+        airflow.settings.initialize()
+
+        manager.assert_has_calls([call.prepare_syspath(), call.import_local_settings()])
+
+    def test_import_with_dunder_all_not_specified(self):
+        """
+        Tests that if __all__ is specified in airflow_local_settings,
+        module attributes not listed in it (``not_policy``) are not imported.
+        """
+        with SettingsContext(SETTINGS_FILE_POLICY_WITH_DUNDER_ALL, "airflow_local_settings"):
+            from airflow import settings
+            settings.import_local_settings()
+
+            with self.assertRaises(AttributeError):
+                settings.not_policy()
+
+    def test_import_with_dunder_all(self):
+        """
+        Tests that if __all__ is specified in airflow_local_settings,
+        the module attributes listed in it (``test_policy``) are imported.
+        """
+        with SettingsContext(SETTINGS_FILE_POLICY_WITH_DUNDER_ALL, "airflow_local_settings"):
+            from airflow import settings
+            settings.import_local_settings()
+
+            task_instance = MagicMock()
+            settings.test_policy(task_instance)
+
+            self.assertEqual(task_instance.run_as_user, "myself")
+
+    @patch("airflow.settings.log.debug")
+    def test_import_local_settings_without_syspath(self, log_mock):
+        """
+        Tests that import_local_settings does not raise when there is no
+        airflow_local_settings module on the syspath, and instead logs the
+        failure at debug level.
+        """
+        from airflow import settings
+        settings.import_local_settings()
+        log_mock.assert_called_with("Failed to import airflow_local_settings.", exc_info=True)
+
+    def test_policy_function(self):
+        """
+        Tests that task instances are mutated by the policy
+        function in airflow_local_settings.
+        """
+        with SettingsContext(SETTINGS_FILE_POLICY, "airflow_local_settings"):
+            from airflow import settings
+            settings.import_local_settings()
+
+            task_instance = MagicMock()
+            settings.test_policy(task_instance)
+
+            self.assertEqual(task_instance.run_as_user, "myself")
+
+    def test_pod_mutation_hook(self):
+        """
+        Tests that pods are mutated by the pod_mutation_hook
+        function in airflow_local_settings.
+        """
+        with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK, "airflow_local_settings"):
+            from airflow import settings
+            settings.import_local_settings()
+
+            pod = MagicMock()
+            pod.volumes = []
+            settings.pod_mutation_hook(pod)
+
+            self.assertEqual(pod.namespace, 'airflow-tests')
+            self.assertEqual(pod.volumes[0].name, "bar")
+
+    def test_pod_mutation_to_k8s_pod(self):
+        """
+        Tests that a hook written against the legacy airflow ``Pod`` attribute API
+        is applied to a ``k8s.V1Pod`` by PodLauncher._mutate_pod_backcompat.
+        """
+        with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK, "airflow_local_settings"):
+            from airflow import settings
+            settings.import_local_settings()
+            from airflow.kubernetes.pod_launcher import PodLauncher
+
+            pod_launcher = PodLauncher(kube_client=Mock())
+            pod = pod_generator.PodGenerator(
+                image="foo",
+                name="bar",
+                namespace="baz",
+                image_pull_policy="Never",
+                cmds=["foo"],
+                args=["/bin/sh", "-c", "touch /tmp/healthy"],
+                tolerations=[
+                    {'effect': 'NoSchedule',
+                     'key': 'static-pods',
+                     'operator': 'Equal',
+                     'value': 'true'}
+                ],
+                volume_mounts=[
+                    {"name": "foo", "mountPath": "/mnt", "subPath": "/", "readOnly": True}
+                ],
+                security_context=k8s.V1PodSecurityContext(fs_group=0, run_as_user=1),
+                volumes=[k8s.V1Volume(name="foo")]
+            ).gen_pod()
+
+            sanitized_pod_pre_mutation = api_client.sanitize_for_serialization(pod)
+            self.assertEqual(
+                sanitized_pod_pre_mutation,
+                {'apiVersion': 'v1',
+                 'kind': 'Pod',
+                 'metadata': {'name': mock.ANY,
+                              'namespace': 'baz'},
+                 'spec': {'containers': [{'args': ['/bin/sh', '-c', 'touch /tmp/healthy'],
+                                          'command': ['foo'],
+                                          'env': [],
+                                          'envFrom': [],
+                                          'image': 'foo',
+                                          'imagePullPolicy': 'Never',
+                                          'name': 'base',
+                                          'ports': [],
+                                          'volumeMounts': [{'mountPath': '/mnt',
+                                                            'name': 'foo',
+                                                            'readOnly': True,
+                                                            'subPath': '/'}]}],
+                          'hostNetwork': False,
+                          'imagePullSecrets': [],
+                          'tolerations': [{'effect': 'NoSchedule',
+                                           'key': 'static-pods',
+                                           'operator': 'Equal',
+                                           'value': 'true'}],
+                          'volumes': [{'name': 'foo'}],
+                          'securityContext': {'fsGroup': 0, 'runAsUser': 1}}},
+            )
+
+            # Apply Pod Mutation Hook
+            pod = pod_launcher._mutate_pod_backcompat(pod)
+
+            sanitized_pod_post_mutation = api_client.sanitize_for_serialization(pod)
+
+            self.assertEqual(
+                sanitized_pod_post_mutation,
+                {"apiVersion": "v1",
+                 "kind": "Pod",
+                 'metadata': {'labels': {'test_label': 'test_value'},
+                              'name': mock.ANY,
+                              'namespace': 'airflow-tests'},
+                 'spec': {'affinity': {'nodeAffinity': {'requiredDuringSchedulingIgnoredDuringExecution': {
+                     'nodeSelectorTerms': [{'matchExpressions': [{'key': 'test/dynamic-pods',
+                                                                  'operator': 'In',
+                                                                  'values': ['true']}]}]}}},
+                          'containers': [{'args': ['/bin/sh', '-c', 'touch /tmp/healthy2'],
+                                          'command': ['foo'],
+                                          'env': [{'name': 'TEST_USER', 'value': 'ADMIN'}],
+                                          'image': 'my_image',
+                                          'imagePullPolicy': 'Never',
+                                          'name': 'base',
+                                          'ports': [{'containerPort': 8080},
+                                                    {'containerPort': 8081}],
+                                          'resources': {'limits': {'nvidia.com/gpu': '200G'},
+                                                        'requests': {'cpu': '200Mi',
+                                                                     'memory': '2G'}},
+                                          'volumeMounts': [{'mountPath': '/opt/airflow/secrets/',
+                                                            'name': 'airflow-secrets-mount',
+                                                            'readOnly': True},
+                                                           {'mountPath': '/mnt',
+                                                            'name': 'foo',
+                                                            'readOnly': True,
+                                                            'subPath': '/'}
+                                                           ]}],
+                          'hostNetwork': False,
+                          'imagePullSecrets': [],
+                          'tolerations': [{'effect': 'NoSchedule',
+                                           'key': 'static-pods',
+                                           'operator': 'Equal',
+                                           'value': 'true'},
+                                          {'effect': 'NoSchedule',
+                                           'key': 'dynamic-pods',
+                                           'operator': 'Equal',
+                                           'value': 'true'}],
+                          'volumes': [{'name': 'airflow-secrets-mount',
+                                       'secret': {'secretName': 'airflow-test-secrets'}},
+                                      {'name': 'bar'},
+                                      {'name': 'foo'},
+                                      ],
+                          'securityContext': {'runAsUser': 1}}}
+            )
+
+    def test_pod_mutation_v1_pod(self):
+        """
+        Tests that a hook written against ``k8s.V1Pod`` is applied by
+        PodLauncher._mutate_pod_backcompat.
+        """
+        with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK_V1_POD, "airflow_local_settings"):
+            from airflow import settings
+            settings.import_local_settings()
+            from airflow.kubernetes.pod_launcher import PodLauncher
+
+            pod_launcher = PodLauncher(kube_client=Mock())
+            pod = pod_generator.PodGenerator(
+                image="myimage",
+                cmds=["foo"],
+                namespace="baz",
+                volume_mounts=[
+                    {"name": "foo", "mountPath": "/mnt", "subPath": "/", "readOnly": True}
+                ],
+                volumes=[{"name": "foo"}]
+            ).gen_pod()
+
+            sanitized_pod_pre_mutation = api_client.sanitize_for_serialization(pod)
+
+            self.assertEqual(
+                sanitized_pod_pre_mutation,
+                {'apiVersion': 'v1',
+                 'kind': 'Pod',
+                 'metadata': {'namespace': 'baz'},
+                 'spec': {'containers': [{'args': [],
+                                          'command': ['foo'],
+                                          'env': [],
+                                          'envFrom': [],
+                                          'image': 'myimage',
+                                          'name': 'base',
+                                          'ports': [],
+                                          'volumeMounts': [{'mountPath': '/mnt',
+                                                            'name': 'foo',
+                                                            'readOnly': True,
+                                                            'subPath': '/'}]}],
+                          'hostNetwork': False,
+                          'imagePullSecrets': [],
+                          'volumes': [{'name': 'foo'}]}}
+            )
+
+            # Apply Pod Mutation Hook
+            pod = pod_launcher._mutate_pod_backcompat(pod)
+
+            sanitized_pod_post_mutation = api_client.sanitize_for_serialization(pod)
+            self.assertEqual(
+                sanitized_pod_post_mutation,
+                {'apiVersion': 'v1',
+                 'kind': 'Pod',
+                 'metadata': {'namespace': 'airflow-tests'},
+                 'spec': {'containers': [{'args': [],
+                                          'command': ['foo'],
+                                          'env': [{'name': 'TEST_USER', 'value': 'ADMIN'}],
+                                          'envFrom': [],
+                                          'image': 'test-image',
+                                          'name': 'base',
+                                          'ports': [{'containerPort': 8080}, {'containerPort': 8081}],
+                                          'volumeMounts': [{'mountPath': '/mnt',
+                                                            'name': 'foo',
+                                                            'readOnly': True,
+                                                            'subPath': '/'},
+                                                           {'mountPath': '/opt/airflow/secrets/',
+                                                            'name': 'airflow-secrets-mount',
+                                                            'readOnly': True}]}],
+                          'hostNetwork': False,
+                          'imagePullSecrets': [],
+                          'volumes': [{'name': 'foo'},
+                                      {'name': 'airflow-secrets-mount',
+                                       'secret': {'secretName': 'airflow-test-secrets'}}]}}
+            )
+
+
+class TestStatsWithAllowList(unittest.TestCase):
+    """Tests that SafeStatsdLogger only forwards metrics permitted by its allow list."""
+
+    def setUp(self):
+        from airflow.settings import SafeStatsdLogger, AllowListValidator
+        self.statsd_client = Mock()
+        self.stats = SafeStatsdLogger(self.statsd_client, AllowListValidator("stats_one, stats_two"))
+
+    def test_increment_counter_with_allowed_key(self):
+        self.stats.incr('stats_one')
+        self.statsd_client.incr.assert_called_once_with('stats_one', 1, 1)
+
+    def test_increment_counter_with_allowed_prefix(self):
+        # A name that starts with an allowed entry is forwarded as well.
+        self.stats.incr('stats_two.bla')
+        self.statsd_client.incr.assert_called_once_with('stats_two.bla', 1, 1)
+
+    def test_not_increment_counter_if_not_allowed(self):
+        self.stats.incr('stats_three')
+        # Assert on the ``incr`` method itself: ``statsd_client.assert_not_called()``
+        # only checks that the client object was never called directly, so it would
+        # pass even if ``incr`` had been invoked.
+        self.statsd_client.incr.assert_not_called()