You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/08/11 22:35:01 UTC
[airflow] 20/32: Fixes PodMutationHook for backwards compatibility (#9903)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit bcd02ddb81a07026dcbbc5e5a4dc669a6483b59b
Author: Daniel Imberman <da...@gmail.com>
AuthorDate: Thu Jul 30 11:40:23 2020 -0700
Fixes PodMutationHook for backwards compatibility (#9903)
Co-authored-by: Daniel Imberman <da...@astronomer.io>
Co-authored-by: Kaxil Naik <ka...@gmail.com>
---
airflow/kubernetes/k8s_model.py | 16 +++
airflow/kubernetes/pod.py | 33 ++++--
airflow/kubernetes/pod_launcher.py | 26 +++-
airflow/kubernetes/pod_launcher_helper.py | 96 +++++++++++++++
airflow/kubernetes/volume_mount.py | 1 +
airflow/kubernetes_deprecated/__init__.py | 16 +++
airflow/kubernetes_deprecated/pod.py | 171 +++++++++++++++++++++++++++
docs/conf.py | 1 +
tests/kubernetes/models/test_pod.py | 81 +++++++++++++
tests/kubernetes/test_pod_launcher_helper.py | 97 +++++++++++++++
tests/test_local_settings.py | 96 +++++++++++++++
11 files changed, 619 insertions(+), 15 deletions(-)
diff --git a/airflow/kubernetes/k8s_model.py b/airflow/kubernetes/k8s_model.py
index 3fd2f9e..e10a946 100644
--- a/airflow/kubernetes/k8s_model.py
+++ b/airflow/kubernetes/k8s_model.py
@@ -29,6 +29,7 @@ else:
class K8SModel(ABC):
+
"""
These Airflow Kubernetes models are here for backwards compatibility
reasons only. Ideally clients should use the kubernetes api
@@ -39,6 +40,7 @@ class K8SModel(ABC):
can be avoided. All of these models implement the
`attach_to_pod` method so that they integrate with the kubernetes client.
"""
+
@abc.abstractmethod
def attach_to_pod(self, pod):
"""
@@ -47,9 +49,23 @@ class K8SModel(ABC):
:return: The pod with the object attached
"""
+    def as_dict(self):
+        """
+        Return the model's attributes as a plain dict.
+
+        Attributes named in ``__slots__`` that are set on the instance are
+        merged over a copy of the instance ``__dict__`` (when it has one).
+        The instance itself is not modified.
+        """
+        slot_values = dict(
+            (slot, getattr(self, slot))
+            for slot in getattr(self, "__slots__", ())
+            if hasattr(self, slot)
+        )
+        if not hasattr(self, "__dict__"):
+            return slot_values
+        merged = dict(self.__dict__)
+        merged.update(slot_values)
+        return merged
+
def append_to_pod(pod, k8s_objects):
"""
+ Attach Kubernetes objects to the given POD
+
:param pod: A pod to attach a list of Kubernetes objects to
:type pod: kubernetes.client.models.V1Pod
:param k8s_objects: a potential None list of K8SModels
diff --git a/airflow/kubernetes/pod.py b/airflow/kubernetes/pod.py
index 0b332c2..9e455af 100644
--- a/airflow/kubernetes/pod.py
+++ b/airflow/kubernetes/pod.py
@@ -26,7 +26,13 @@ from airflow.kubernetes.k8s_model import K8SModel
class Resources(K8SModel):
- __slots__ = ('request_memory', 'request_cpu', 'limit_memory', 'limit_cpu', 'limit_gpu')
+ __slots__ = ('request_memory',
+ 'request_cpu',
+ 'limit_memory',
+ 'limit_cpu',
+ 'limit_gpu',
+ 'request_ephemeral_storage',
+ 'limit_ephemeral_storage')
"""
:param request_memory: requested memory
@@ -44,15 +50,17 @@ class Resources(K8SModel):
:param limit_ephemeral_storage: Limit for ephemeral storage
:type limit_ephemeral_storage: float | str
"""
+
def __init__(
- self,
- request_memory=None,
- request_cpu=None,
- request_ephemeral_storage=None,
- limit_memory=None,
- limit_cpu=None,
- limit_gpu=None,
- limit_ephemeral_storage=None):
+ self,
+ request_memory=None,
+ request_cpu=None,
+ request_ephemeral_storage=None,
+ limit_memory=None,
+ limit_cpu=None,
+ limit_gpu=None,
+ limit_ephemeral_storage=None
+ ):
self.request_memory = request_memory
self.request_cpu = request_cpu
self.request_ephemeral_storage = request_ephemeral_storage
@@ -104,9 +112,10 @@ class Port(K8SModel):
__slots__ = ('name', 'container_port')
def __init__(
- self,
- name=None,
- container_port=None):
+ self,
+ name=None,
+ container_port=None
+ ):
"""Creates port"""
self.name = name
self.container_port = container_port
diff --git a/airflow/kubernetes/pod_launcher.py b/airflow/kubernetes/pod_launcher.py
index d27a647..05df204 100644
--- a/airflow/kubernetes/pod_launcher.py
+++ b/airflow/kubernetes/pod_launcher.py
@@ -26,10 +26,12 @@ from kubernetes.stream import stream as kubernetes_stream
from requests.exceptions import BaseHTTPError
from airflow import AirflowException
+from airflow.kubernetes.pod_launcher_helper import convert_to_airflow_pod
from airflow.kubernetes.pod_generator import PodDefaults
-from airflow.settings import pod_mutation_hook
+from airflow import settings
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import State
+import kubernetes.client.models as k8s # noqa
from .kube_client import get_kube_client
@@ -62,8 +64,12 @@ class PodLauncher(LoggingMixin):
self.extract_xcom = extract_xcom
def run_pod_async(self, pod, **kwargs):
- """Runs POD asynchronously"""
- pod_mutation_hook(pod)
+ """Runs POD asynchronously
+
+ :param pod: Pod to run
+ :type pod: k8s.V1Pod
+ """
+ pod = self._mutate_pod_backcompat(pod)
sanitized_pod = self._client.api_client.sanitize_for_serialization(pod)
json_pod = json.dumps(sanitized_pod, indent=2)
@@ -79,6 +85,20 @@ class PodLauncher(LoggingMixin):
raise e
return resp
+    @staticmethod
+    def _mutate_pod_backcompat(pod):
+        """
+        Backwards compatible Pod Mutation Hook.
+
+        ``settings.pod_mutation_hook`` is first called with the ``k8s.V1Pod``
+        itself. If that raises ``AttributeError`` (for example a hook written
+        against the deprecated Airflow ``Pod`` model), the pod is converted to
+        the deprecated model, the hook is run on that, and the result is
+        converted back into a ``k8s.V1Pod``.
+
+        :param pod: pod to mutate
+        :type pod: k8s.V1Pod
+        :return: ``pod`` itself when the hook accepted it, otherwise a new
+            V1Pod rebuilt from the deprecated model
+        :rtype: k8s.V1Pod
+        """
+        try:
+            settings.pod_mutation_hook(pod)
+        except AttributeError:
+            # NOTE(review): anything the hook changed on ``pod`` before raising
+            # is still present when it is converted below; the round trip also
+            # keeps only what convert_to_airflow_pod / to_v1_kubernetes_pod
+            # carry over.
+            legacy_pod = convert_to_airflow_pod(pod)
+            settings.pod_mutation_hook(legacy_pod)
+            return legacy_pod.to_v1_kubernetes_pod()
+        return pod
+
def delete_pod(self, pod):
"""Deletes POD"""
try:
diff --git a/airflow/kubernetes/pod_launcher_helper.py b/airflow/kubernetes/pod_launcher_helper.py
new file mode 100644
index 0000000..d8b2698
--- /dev/null
+++ b/airflow/kubernetes/pod_launcher_helper.py
@@ -0,0 +1,96 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import List, Union
+
+import kubernetes.client.models as k8s # noqa
+
+from airflow.kubernetes.volume import Volume
+from airflow.kubernetes.volume_mount import VolumeMount
+from airflow.kubernetes.pod import Port
+from airflow.kubernetes_deprecated.pod import Pod
+
+
+def convert_to_airflow_pod(pod):
+    """
+    Convert a ``k8s.V1Pod`` into the deprecated Airflow ``Pod`` model.
+
+    Used so that ``pod_mutation_hook`` implementations written against the
+    deprecated model can still be applied. Only the first container is read.
+    Every field that ``Pod.to_v1_kubernetes_pod`` writes back is carried over,
+    so converting there and back again does not silently drop it.
+
+    NOTE(review): the container command is not carried over (``cmds`` is
+    always ``[]``), so a pod that goes through this conversion loses it.
+
+    :param pod: the pod to convert
+    :type pod: k8s.V1Pod
+    :return: the equivalent deprecated pod
+    :rtype: airflow.kubernetes_deprecated.pod.Pod
+    """
+    base_container = pod.spec.containers[0]  # type: k8s.V1Container
+
+    dummy_pod = Pod(
+        image=base_container.image,
+        envs=_extract_env_vars(base_container.env),
+        cmds=[],
+        args=base_container.args,
+        labels=pod.metadata.labels,
+        name=pod.metadata.name,
+        ports=_extract_ports(base_container.ports),
+        volumes=_extract_volumes(pod.spec.volumes),
+        volume_mounts=_extract_volume_mounts(base_container.volume_mounts),
+        namespace=pod.metadata.namespace,
+        image_pull_policy=base_container.image_pull_policy or 'IfNotPresent',
+        # The following are passed through untouched and written back
+        # verbatim by Pod.to_v1_kubernetes_pod.
+        image_pull_secrets=pod.spec.image_pull_secrets,
+        init_containers=pod.spec.init_containers,
+        service_account_name=pod.spec.service_account_name,
+        hostnetwork=pod.spec.host_network,
+        tolerations=pod.spec.tolerations,
+        security_context=pod.spec.security_context,
+        dnspolicy=pod.spec.dns_policy,
+        resources=_extract_resources(base_container.resources),
+    )
+    return dummy_pod
+
+
+def _extract_resources(resources):
+    """
+    Convert container resource requirements to the deprecated ``Resources`` model.
+
+    :param resources: requests/limits of a container, as a
+        ``k8s.V1ResourceRequirements``, an equivalent dict, or ``None``
+    :return: a Resources object; every field is None when nothing is set
+    :rtype: airflow.kubernetes.pod.Resources
+    """
+    # Imported here because only Port is imported from this module at file level.
+    from airflow.kubernetes.pod import Resources
+
+    if isinstance(resources, k8s.V1ResourceRequirements):
+        resources = resources.to_dict()
+    resources = resources or {}
+    requests = resources.get("requests") or {}
+    limits = resources.get("limits") or {}
+    # Keys mirror the ones Resources emits when attached to a pod.
+    return Resources(
+        request_memory=requests.get("memory"),
+        request_cpu=requests.get("cpu"),
+        request_ephemeral_storage=requests.get("ephemeral-storage"),
+        limit_memory=limits.get("memory"),
+        limit_cpu=limits.get("cpu"),
+        limit_gpu=limits.get("nvidia.com/gpu"),
+        limit_ephemeral_storage=limits.get("ephemeral-storage"),
+    )
+
+
+def _extract_env_vars(env_vars):
+    """
+    Convert container environment variables to the deprecated ``envs`` dict.
+
+    :param env_vars: environment variables of a container; items may be
+        ``k8s.V1EnvVar`` objects or plain dicts. ``None`` is treated as empty.
+    :type env_vars: list
+    :return: mapping of variable name to its ``value``
+    :rtype: dict
+    """
+    result = {}
+    env_vars = env_vars or []  # type: List[Union[k8s.V1EnvVar, dict]]
+    for env_var in env_vars:
+        if isinstance(env_var, k8s.V1EnvVar):
+            # Rebind to the dict form so the .get() calls below work for both
+            # input kinds (V1EnvVar itself has no .get()).
+            env_var = env_var.to_dict()
+        # NOTE(review): only the "value" key is read; a variable whose value is
+        # supplied another way (e.g. value_from) presumably maps to None.
+        result[env_var.get("name")] = env_var.get("value")
+    return result
+
+
+def _extract_volumes(volumes):
+    """
+    Wrap pod volumes in deprecated Airflow ``Volume`` models.
+
+    :param volumes: pod volumes as ``k8s.V1Volume`` objects or dicts;
+        ``None`` means no volumes
+    :type volumes: list
+    :return: one Volume per input, whose ``configs`` is the volume's dict form
+    :rtype: list[airflow.kubernetes.volume.Volume]
+    """
+    configs = (
+        vol.to_dict() if isinstance(vol, k8s.V1Volume) else vol
+        for vol in (volumes or [])
+    )
+    return [Volume(name=cfg.get("name"), configs=cfg) for cfg in configs]
+
+
+def _extract_volume_mounts(volume_mounts):
+    """
+    Wrap container volume mounts in deprecated Airflow ``VolumeMount`` models.
+
+    :param volume_mounts: mounts as ``k8s.V1VolumeMount`` objects or dicts;
+        ``None`` means no mounts
+    :type volume_mounts: list
+    :return: one VolumeMount per input mount
+    :rtype: list[airflow.kubernetes.volume_mount.VolumeMount]
+    """
+    def _to_model(mount):
+        if isinstance(mount, k8s.V1VolumeMount):
+            mount = mount.to_dict()
+        return VolumeMount(
+            name=mount.get("name"),
+            mount_path=mount.get("mount_path"),
+            sub_path=mount.get("sub_path"),
+            read_only=mount.get("read_only"),
+        )
+
+    return [_to_model(mount) for mount in (volume_mounts or [])]
+
+
+def _extract_ports(ports):
+    """
+    Wrap container ports in deprecated Airflow ``Port`` models.
+
+    :param ports: ports as ``k8s.V1ContainerPort`` objects or dicts;
+        ``None`` means no ports
+    :type ports: list
+    :return: one Port per input, keeping only ``name`` and ``container_port``
+    :rtype: list[airflow.kubernetes.pod.Port]
+    """
+    converted = []
+    for entry in ports or []:
+        fields = entry.to_dict() if isinstance(entry, k8s.V1ContainerPort) else entry
+        converted.append(
+            Port(name=fields.get("name"), container_port=fields.get("container_port"))
+        )
+    return converted
diff --git a/airflow/kubernetes/volume_mount.py b/airflow/kubernetes/volume_mount.py
index 0dbca5f..ab87ba9 100644
--- a/airflow/kubernetes/volume_mount.py
+++ b/airflow/kubernetes/volume_mount.py
@@ -24,6 +24,7 @@ from airflow.kubernetes.k8s_model import K8SModel
class VolumeMount(K8SModel):
+ __slots__ = ('name', 'mount_path', 'sub_path', 'read_only')
"""
Initialize a Kubernetes Volume Mount. Used to mount pod level volumes to
running container.
diff --git a/airflow/kubernetes_deprecated/__init__.py b/airflow/kubernetes_deprecated/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/airflow/kubernetes_deprecated/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/kubernetes_deprecated/pod.py b/airflow/kubernetes_deprecated/pod.py
new file mode 100644
index 0000000..22a8c12
--- /dev/null
+++ b/airflow/kubernetes_deprecated/pod.py
@@ -0,0 +1,171 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import kubernetes.client.models as k8s
+from airflow.kubernetes.pod import Resources
+
+
+class Pod(object):
+    """
+    Represents a kubernetes pod and manages execution of a single pod.
+
+    Deprecated model; ``to_v1_kubernetes_pod`` converts it to a ``k8s.V1Pod``.
+
+    :param image: The docker image
+    :type image: str
+    :param envs: A dict containing the environment variables
+    :type envs: dict
+    :param cmds: The command to be run on the pod
+    :type cmds: list[str]
+    :param args: Arguments passed to the command
+    :type args: list[str]
+    :param secrets: Secrets to be launched to the pod
+    :type secrets: list[airflow.contrib.kubernetes.secret.Secret]
+    :param labels: Labels set on the pod metadata
+    :type labels: dict
+    :param node_selectors: Node selectors (stored only; not emitted by
+        ``to_v1_kubernetes_pod``)
+    :type node_selectors: dict
+    :param name: Name of the pod
+    :type name: str
+    :param ports: Ports attached to the container
+    :type ports: list[airflow.kubernetes.pod.Port]
+    :param volumes: Volumes attached to the pod
+    :type volumes: list[airflow.kubernetes.volume.Volume]
+    :param volume_mounts: Volume mounts attached to the container
+    :type volume_mounts: list[airflow.kubernetes.volume_mount.VolumeMount]
+    :param namespace: Namespace of the pod
+    :type namespace: str
+    :param result: The result that will be returned to the operator after
+        successful execution of the pod
+    :type result: any
+    :param image_pull_policy: Specify a policy to cache or always pull an image
+    :type image_pull_policy: str
+    :param image_pull_secrets: Any image pull secrets to be given to the pod.
+        If more than one secret is required, provide a comma separated list:
+        secret_a,secret_b
+    :type image_pull_secrets: str
+    :param init_containers: Init containers, passed through unchanged
+    :type init_containers: list
+    :param service_account_name: Service account the pod runs as
+    :type service_account_name: str
+    :param resources: Resource requests and limits; an empty ``Resources()``
+        is used when not given
+    :type resources: airflow.kubernetes.pod.Resources
+    :param annotations: Annotations (stored only; not emitted by
+        ``to_v1_kubernetes_pod``)
+    :type annotations: dict
+    :param affinity: A dict containing a group of affinity scheduling rules
+        (stored only; not emitted by ``to_v1_kubernetes_pod``)
+    :type affinity: dict
+    :param hostnetwork: If True enable host networking on the pod
+    :type hostnetwork: bool
+    :param tolerations: A list of kubernetes tolerations
+    :type tolerations: list
+    :param security_context: A dict containing the security context for the pod
+    :type security_context: dict
+    :param configmaps: A list containing names of configmaps object
+        mounting env variables to the pod (stored only; not emitted by
+        ``to_v1_kubernetes_pod``)
+    :type configmaps: list[str]
+    :param pod_runtime_info_envs: environment variables about
+        pod runtime information (ip, namespace, nodeName, podName)
+    :type pod_runtime_info_envs: list[PodRuntimeEnv]
+    :param dnspolicy: Specify a dnspolicy for the pod
+    :type dnspolicy: str
+    """
+    def __init__(
+            self,
+            image,
+            envs,
+            cmds,
+            args=None,
+            secrets=None,
+            labels=None,
+            node_selectors=None,
+            name=None,
+            ports=None,
+            volumes=None,
+            volume_mounts=None,
+            namespace='default',
+            result=None,
+            image_pull_policy='IfNotPresent',
+            image_pull_secrets=None,
+            init_containers=None,
+            service_account_name=None,
+            resources=None,
+            annotations=None,
+            affinity=None,
+            hostnetwork=False,
+            tolerations=None,
+            security_context=None,
+            configmaps=None,
+            pod_runtime_info_envs=None,
+            dnspolicy=None
+    ):
+        self.image = image
+        self.envs = envs or {}
+        self.cmds = cmds
+        self.args = args or []
+        self.secrets = secrets or []
+        self.result = result
+        self.labels = labels or {}
+        self.name = name
+        self.ports = ports or []
+        self.volumes = volumes or []
+        self.volume_mounts = volume_mounts or []
+        self.node_selectors = node_selectors or {}
+        self.namespace = namespace
+        self.image_pull_policy = image_pull_policy
+        self.image_pull_secrets = image_pull_secrets
+        self.init_containers = init_containers
+        self.service_account_name = service_account_name
+        self.resources = resources or Resources()
+        self.annotations = annotations or {}
+        self.affinity = affinity or {}
+        self.hostnetwork = hostnetwork or False
+        self.tolerations = tolerations or []
+        self.security_context = security_context
+        self.configmaps = configmaps or []
+        self.pod_runtime_info_envs = pod_runtime_info_envs or []
+        self.dnspolicy = dnspolicy
+
+    def to_v1_kubernetes_pod(self):
+        """
+        Convert to support k8s V1Pod
+
+        Builds a pod with a single container named ``base``, then lets each
+        attached model (ports, volumes, volume mounts, secrets, runtime info
+        envs and resources) attach itself. ``annotations``,
+        ``node_selectors``, ``affinity``, ``configmaps`` and ``result`` are
+        not included in the output.
+
+        :return: k8s.V1Pod
+        """
+        meta = k8s.V1ObjectMeta(
+            labels=self.labels,
+            name=self.name,
+            namespace=self.namespace,
+        )
+        spec = k8s.V1PodSpec(
+            init_containers=self.init_containers,
+            containers=[
+                k8s.V1Container(
+                    image=self.image,
+                    command=self.cmds,
+                    name="base",
+                    env=[k8s.V1EnvVar(name=key, value=val) for key, val in self.envs.items()],
+                    args=self.args,
+                    image_pull_policy=self.image_pull_policy,
+                )
+            ],
+            # NOTE(review): passed through unchanged; V1PodSpec expects a list
+            # of V1LocalObjectReference rather than the comma separated str
+            # described above -- confirm what callers actually pass.
+            image_pull_secrets=self.image_pull_secrets,
+            service_account_name=self.service_account_name,
+            dns_policy=self.dnspolicy,
+            host_network=self.hostnetwork,
+            tolerations=self.tolerations,
+            security_context=self.security_context,
+        )
+
+        pod = k8s.V1Pod(
+            spec=spec,
+            metadata=meta,
+        )
+        # Each model's attach_to_pod returns the pod with the object attached.
+        for port in self.ports:
+            pod = port.attach_to_pod(pod)
+        for volume in self.volumes:
+            pod = volume.attach_to_pod(pod)
+        for volume_mount in self.volume_mounts:
+            pod = volume_mount.attach_to_pod(pod)
+        for secret in self.secrets:
+            pod = secret.attach_to_pod(pod)
+        for runtime_info in self.pod_runtime_info_envs:
+            pod = runtime_info.attach_to_pod(pod)
+        pod = self.resources.attach_to_pod(pod)
+        return pod
+
+    def as_dict(self):
+        """
+        Return this pod's attributes as a dict, with nested models as dicts.
+
+        ``resources``, ``ports``, ``volume_mounts`` and ``volumes`` are
+        replaced by their ``as_dict()`` forms in the returned dict only; the
+        pod itself keeps its model objects, so this can be called repeatedly
+        and ``to_v1_kubernetes_pod`` still works afterwards.
+
+        :rtype: dict
+        """
+        # Copy first: writing into self.__dict__ would replace the pod's own
+        # attributes with plain dicts.
+        res = dict(self.__dict__)
+        res['resources'] = res['resources'].as_dict()
+        res['ports'] = [port.as_dict() for port in res['ports']]
+        res['volume_mounts'] = [volume_mount.as_dict() for volume_mount in res['volume_mounts']]
+        res['volumes'] = [volume.as_dict() for volume in res['volumes']]
+
+        return res
diff --git a/docs/conf.py b/docs/conf.py
index 6df66f8..d18b6ea 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -201,6 +201,7 @@ exclude_patterns = [
'_api/airflow/example_dags',
'_api/airflow/index.rst',
'_api/airflow/jobs',
+ '_api/airflow/kubernetes_deprecated',
'_api/airflow/lineage',
'_api/airflow/logging_config',
'_api/airflow/macros',
diff --git a/tests/kubernetes/models/test_pod.py b/tests/kubernetes/models/test_pod.py
index 45c32aa..096b5f0 100644
--- a/tests/kubernetes/models/test_pod.py
+++ b/tests/kubernetes/models/test_pod.py
@@ -74,3 +74,84 @@ class TestPod(unittest.TestCase):
'volumes': []
}
}, result)
+
+    def test_to_v1_pod(self):
+        """A deprecated Pod converts to the expected serialized k8s.V1Pod."""
+        from airflow.kubernetes_deprecated.pod import Pod as DeprecatedPod
+        from airflow.kubernetes.volume import Volume
+        from airflow.kubernetes.volume_mount import VolumeMount
+        from airflow.kubernetes.pod import Resources
+
+        pod = DeprecatedPod(
+            image="foo",
+            name="bar",
+            namespace="baz",
+            image_pull_policy="Never",
+            envs={"test_key": "test_value"},
+            cmds=["airflow"],
+            resources=Resources(
+                request_memory="1G",
+                request_cpu="100Mi",
+                limit_gpu="100G"
+            ),
+            volumes=[Volume(name="foo", configs={})],
+            volume_mounts=[VolumeMount(name="foo", mount_path="/mnt", sub_path="/", read_only=True)]
+        )
+
+        k8s_client = ApiClient()
+
+        result = pod.to_v1_kubernetes_pod()
+        result = k8s_client.sanitize_for_serialization(result)
+
+        expected = {
+            'metadata': {
+                'labels': {},
+                'name': 'bar',
+                'namespace': 'baz'
+            },
+            'spec': {
+                'containers': [
+                    {
+                        'args': [],
+                        'command': ['airflow'],
+                        'env': [{'name': 'test_key', 'value': 'test_value'}],
+                        'image': 'foo',
+                        'imagePullPolicy': 'Never',
+                        'name': 'base',
+                        'volumeMounts': [
+                            {
+                                'mountPath': '/mnt',
+                                'name': 'foo',
+                                'readOnly': True,
+                                'subPath': '/'
+                            }
+                        ],
+                        'resources': {
+                            'limits': {
+                                'cpu': None,
+                                'memory': None,
+                                'nvidia.com/gpu': '100G',
+                                'ephemeral-storage': None
+                            },
+                            'requests': {
+                                'cpu': '100Mi',
+                                'memory': '1G',
+                                'ephemeral-storage': None
+                            }
+                        }
+                    }
+                ],
+                'hostNetwork': False,
+                'tolerations': [],
+                'volumes': [
+                    {'name': 'foo'}
+                ]
+            }
+        }
+        self.maxDiff = None
+        self.assertEqual(expected, result)
diff --git a/tests/kubernetes/test_pod_launcher_helper.py b/tests/kubernetes/test_pod_launcher_helper.py
new file mode 100644
index 0000000..a308ac3
--- /dev/null
+++ b/tests/kubernetes/test_pod_launcher_helper.py
@@ -0,0 +1,97 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import unittest
+
+from airflow.kubernetes.pod import Port
+from airflow.kubernetes.volume_mount import VolumeMount
+from airflow.kubernetes.volume import Volume
+from airflow.kubernetes.pod_launcher_helper import convert_to_airflow_pod
+from airflow.kubernetes_deprecated.pod import Pod
+import kubernetes.client.models as k8s
+
+
+class TestPodLauncherHelper(unittest.TestCase):
+    def test_convert_to_airflow_pod(self):
+        """A V1Pod converts to the equivalent deprecated Airflow Pod."""
+        container = k8s.V1Container(
+            name="base",
+            command="foo",
+            image="myimage",
+            ports=[k8s.V1ContainerPort(name="myport", container_port=8080)],
+            volume_mounts=[
+                k8s.V1VolumeMount(name="mymount", mount_path="/tmp/mount", read_only="True")
+            ],
+        )
+        input_pod = k8s.V1Pod(
+            metadata=k8s.V1ObjectMeta(name="foo", namespace="bar"),
+            spec=k8s.V1PodSpec(
+                containers=[container],
+                volumes=[k8s.V1Volume(name="myvolume")],
+            ),
+        )
+        result_pod = convert_to_airflow_pod(input_pod)
+
+        expected = Pod(
+            name="foo",
+            namespace="bar",
+            envs={},
+            cmds=[],
+            image="myimage",
+            ports=[Port(name="myport", container_port=8080)],
+            volume_mounts=[
+                VolumeMount(
+                    name="mymount",
+                    mount_path="/tmp/mount",
+                    sub_path=None,
+                    read_only="True",
+                )
+            ],
+            volumes=[Volume(name="myvolume", configs={'name': 'myvolume'})],
+        )
+        expected_dict = expected.as_dict()
+        result_dict = result_pod.as_dict()
+        # The converted volume configs come from V1Volume.to_dict(); keep only
+        # the populated keys before comparing.
+        result_dict['volumes'] = self.pull_out_volumes(result_dict)
+        self.maxDiff = None
+
+        self.assertDictEqual(expected_dict, result_dict)
+
+    def pull_out_volumes(self, result_dict):
+        """Return result_dict's volumes, keeping only truthy, non-underscore config entries."""
+        return [
+            {
+                'name': volume['name'],
+                'configs': {
+                    key: value
+                    for key, value in volume['configs'].items()
+                    if value and key[0] != '_'
+                },
+            }
+            for volume in result_dict['volumes']
+        ]
diff --git a/tests/test_local_settings.py b/tests/test_local_settings.py
index 3497ee2..0e45ad8 100644
--- a/tests/test_local_settings.py
+++ b/tests/test_local_settings.py
@@ -21,6 +21,7 @@ import os
import sys
import tempfile
import unittest
+from airflow.kubernetes import pod_generator
from tests.compat import MagicMock, Mock, call, patch
@@ -40,8 +41,26 @@ def not_policy():
"""
+# airflow_local_settings source whose hook uses attributes of the deprecated
+# Airflow Pod model (namespace, image, volumes, ports, resources).
SETTINGS_FILE_POD_MUTATION_HOOK = """
+from airflow.kubernetes.volume import Volume
+from airflow.kubernetes.pod import Port, Resources
+
def pod_mutation_hook(pod):
pod.namespace = 'airflow-tests'
+ pod.image = 'my_image'
+ pod.volumes.append(Volume(name="bar", configs={}))
+ pod.ports = [Port(container_port=8080)]
+ pod.resources = Resources(
+ request_memory="2G",
+ request_cpu="200Mi",
+ limit_gpu="200G"
+ )
+
+"""
+
+# airflow_local_settings source whose hook edits a k8s.V1Pod directly.
+SETTINGS_FILE_POD_MUTATION_HOOK_V1_POD = """
+def pod_mutation_hook(pod):
+ pod.spec.containers[0].image = "test-image"
+
"""
@@ -148,9 +167,86 @@ class LocalSettingsTest(unittest.TestCase):
settings.import_local_settings() # pylint: ignore
pod = MagicMock()
+ pod.volumes = []
settings.pod_mutation_hook(pod)
assert pod.namespace == 'airflow-tests'
+ self.assertEqual(pod.volumes[0].name, "bar")
+
+    def test_pod_mutation_to_k8s_pod(self):
+        """A hook written for the deprecated Pod model still mutates a k8s.V1Pod."""
+        with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK, "airflow_local_settings"):
+            from airflow import settings
+            settings.import_local_settings()  # pylint: ignore
+            from airflow.kubernetes.pod_launcher import PodLauncher
+
+            launcher = PodLauncher(kube_client=Mock())
+            generator = pod_generator.PodGenerator(
+                image="foo",
+                name="bar",
+                namespace="baz",
+                image_pull_policy="Never",
+                cmds=["foo"],
+                volume_mounts=[
+                    {"name": "foo", "mount_path": "/mnt", "sub_path": "/", "read_only": "True"}
+                ],
+                volumes=[{"name": "foo"}]
+            )
+            pod = generator.gen_pod()
+
+            # Before the hook: values come straight from the generator.
+            container = pod.spec.containers[0]
+            self.assertEqual(pod.metadata.namespace, "baz")
+            self.assertEqual(container.image, "foo")
+            self.assertEqual(pod.spec.volumes, [{'name': 'foo'}])
+            self.assertEqual(container.ports, [])
+            self.assertEqual(container.resources, None)
+
+            pod = launcher._mutate_pod_backcompat(pod)
+            container = pod.spec.containers[0]
+
+            # After the hook: every field the hook touched has been applied.
+            self.assertEqual(pod.metadata.namespace, "airflow-tests")
+            self.assertEqual(container.image, "my_image")
+            self.assertEqual(pod.spec.volumes, [{'name': 'foo'}, {'name': 'bar'}])
+            self.maxDiff = None
+            expected_port = {
+                "container_port": 8080,
+                "host_ip": None,
+                "host_port": None,
+                "name": None,
+                "protocol": None
+            }
+            self.assertEqual(container.ports[0].to_dict(), expected_port)
+            expected_resources = {
+                'limits': {
+                    'cpu': None,
+                    'memory': None,
+                    'ephemeral-storage': None,
+                    'nvidia.com/gpu': '200G'
+                },
+                'requests': {'cpu': '200Mi', 'ephemeral-storage': None, 'memory': '2G'}
+            }
+            self.assertEqual(container.resources.to_dict(), expected_resources)
+
+    def test_pod_mutation_v1_pod(self):
+        """A hook written against k8s.V1Pod is applied to the pod directly."""
+        with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK_V1_POD, "airflow_local_settings"):
+            from airflow import settings
+            settings.import_local_settings()  # pylint: ignore
+            from airflow.kubernetes.pod_launcher import PodLauncher
+
+            launcher = PodLauncher(kube_client=Mock())
+            pod = pod_generator.PodGenerator(
+                image="myimage",
+                cmds=["foo"],
+                # volume_mounts takes a list of mounts, not a single mount dict.
+                volume_mounts=[
+                    {"name": "foo", "mount_path": "/mnt", "sub_path": "/", "read_only": "True"}
+                ],
+                volumes=[{"name": "foo"}]
+            ).gen_pod()
+
+            self.assertEqual(pod.spec.containers[0].image, "myimage")
+            mutated = launcher._mutate_pod_backcompat(pod)
+            # The hook accepts the V1Pod, so the same object is mutated in place
+            # instead of being rebuilt through the deprecated model.
+            self.assertIs(mutated, pod)
+            self.assertEqual(mutated.spec.containers[0].image, "test-image")
class TestStatsWithAllowList(unittest.TestCase):