You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2020/10/03 19:48:19 UTC
[airflow] 06/14: Allow overrides for pod_template_file (#11162)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 00203dbbd1fc372d3770f0ec858d95b4330a0cfa
Author: Daniel Imberman <da...@gmail.com>
AuthorDate: Sun Sep 27 14:39:35 2020 -0700
Allow overrides for pod_template_file (#11162)
* Allow overrides for pod_template_file
A pod_template_file should be treated as a *template*, not a steadfast
rule.
This PR ensures that users can override individual values set by the
pod_template_file so that the same file can be used for multiple tasks.
* fix podtemplatetest
* fix name
(cherry picked from commit a888198c27bcdbc4538c02360c308ffcaca182fa)
---
.../contrib/operators/kubernetes_pod_operator.py | 33 ++++--
airflow/kubernetes/pod_generator.py | 48 ---------
kubernetes_tests/test_kubernetes_pod_operator.py | 116 ++++++++++++++-------
tests/kubernetes/test_pod_generator.py | 16 +--
4 files changed, 108 insertions(+), 105 deletions(-)
diff --git a/airflow/contrib/operators/kubernetes_pod_operator.py b/airflow/contrib/operators/kubernetes_pod_operator.py
index cdf5076..7754fd7 100644
--- a/airflow/contrib/operators/kubernetes_pod_operator.py
+++ b/airflow/contrib/operators/kubernetes_pod_operator.py
@@ -20,14 +20,16 @@ import re
import yaml
from airflow.exceptions import AirflowException
-from airflow.kubernetes import kube_client, pod_generator, pod_launcher
from airflow.kubernetes.k8s_model import append_to_pod
+from airflow.kubernetes import kube_client, pod_generator, pod_launcher
from airflow.kubernetes.pod import Resources
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
from airflow.utils.helpers import validate_key
from airflow.utils.state import State
from airflow.version import version as airflow_version
+from airflow.kubernetes.pod_generator import PodGenerator
+from kubernetes.client import models as k8s
class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
@@ -218,8 +220,9 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance-
self.annotations = annotations or {}
self.affinity = affinity or {}
self.resources = self._set_resources(resources) # noqa
+ self.k8s_resources = self.resources
self.config_file = config_file
- self.image_pull_secrets = image_pull_secrets
+ self.image_pull_secrets = image_pull_secrets or []
self.service_account_name = service_account_name
self.is_delete_operator_pod = is_delete_operator_pod
self.hostnetwork = hostnetwork
@@ -272,6 +275,9 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance-
client = kube_client.get_kube_client(cluster_context=self.cluster_context,
config_file=self.config_file)
+ self.pod = self.create_pod_request_obj()
+ self.namespace = self.pod.metadata.namespace
+
self.client = client
# Add combination of labels to uniquely identify a running pod
@@ -356,6 +362,11 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance-
Creates a V1Pod based on user parameters. Note that a `pod` or `pod_template_file`
will supersede all other values.
"""
+ if self.pod_template_file:
+ pod_template = pod_generator.PodGenerator.deserialize_model_file(self.pod_template_file)
+ else:
+ pod_template = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="name"))
+
pod = pod_generator.PodGenerator(
image=self.image,
namespace=self.namespace,
@@ -373,15 +384,12 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance-
service_account_name=self.service_account_name,
hostnetwork=self.hostnetwork,
tolerations=self.tolerations,
- configmaps=self.configmaps,
security_context=self.security_context,
dnspolicy=self.dnspolicy,
init_containers=self.init_containers,
restart_policy='Never',
schedulername=self.schedulername,
- pod_template_file=self.pod_template_file,
priority_class_name=self.priority_class_name,
- pod=self.full_pod_spec,
).gen_pod()
# noinspection PyTypeChecker
@@ -395,6 +403,17 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance-
self.volume_mounts # type: ignore
)
+ env_from = pod.spec.containers[0].env_from or []
+ for configmap in self.configmaps:
+ env_from.append(k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name=configmap)))
+ pod.spec.containers[0].env_from = env_from
+
+ if self.full_pod_spec:
+ pod_template = PodGenerator.reconcile_pods(pod_template, self.full_pod_spec)
+ pod = PodGenerator.reconcile_pods(pod_template, pod)
+
+ # if self.do_xcom_push:
+ # pod = PodGenerator.add_sidecar(pod)
return pod
def create_new_pod_for_operator(self, labels, launcher):
@@ -435,9 +454,9 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance-
def monitor_launched_pod(self, launcher, pod):
"""
- Montitors a pod to completion that was created by a previous KubernetesPodOperator
+ Monitors a pod to completion that was created by a previous KubernetesPodOperator
- @param launcher: pod launcher that will manage launching and monitoring pods
+ :param launcher: pod launcher that will manage launching and monitoring pods
:param pod: podspec used to find pod using k8s API
:return:
"""
diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py
index ed518d1..4fbfec1 100644
--- a/airflow/kubernetes/pod_generator.py
+++ b/airflow/kubernetes/pod_generator.py
@@ -24,11 +24,6 @@ is supported and no serialization need be written.
import copy
import hashlib
import re
-try:
- from inspect import signature
-except ImportError:
- # Python 2.7
- from funcsigs import signature # type: ignore
import os
import uuid
from functools import reduce
@@ -203,7 +198,6 @@ class PodGenerator(object):
pod_template_file=None,
extract_xcom=False,
):
- self.validate_pod_generator_args(locals())
if pod_template_file:
self.ud_pod = self.deserialize_model_file(pod_template_file)
@@ -556,48 +550,6 @@ class PodGenerator(object):
# pylint: disable=protected-access
return api_client._ApiClient__deserialize_model(pod, k8s.V1Pod)
- @staticmethod
- def validate_pod_generator_args(given_args):
- """
- :param given_args: The arguments passed to the PodGenerator constructor.
- :type given_args: dict
- :return: None
-
- Validate that if `pod` or `pod_template_file` are set that the user is not attempting
- to configure the pod with the other arguments.
- """
- pod_args = list(signature(PodGenerator).parameters.items())
-
- def predicate(k, v):
- """
- :param k: an arg to PodGenerator
- :type k: string
- :param v: the parameter of the given arg
- :type v: inspect.Parameter
- :return: bool
-
- returns True if the PodGenerator argument has no default arguments
- or the default argument is None, and it is not one of the listed field
- in `non_empty_fields`.
- """
- non_empty_fields = {
- 'pod', 'pod_template_file', 'extract_xcom', 'service_account_name', 'image_pull_policy',
- 'restart_policy'
- }
-
- return (v.default is None or v.default is v.empty) and k not in non_empty_fields
-
- args_without_defaults = {k: given_args[k] for k, v in pod_args if predicate(k, v) and given_args[k]}
-
- if given_args['pod'] and given_args['pod_template_file']:
- raise AirflowConfigException("Cannot pass both `pod` and `pod_template_file` arguments")
- if args_without_defaults and (given_args['pod'] or given_args['pod_template_file']):
- raise AirflowConfigException(
- "Cannot configure pod and pass either `pod` or `pod_template_file`. Fields {} passed.".format(
- list(args_without_defaults.keys())
- )
- )
-
def merge_objects(base_obj, client_obj):
"""
diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py
index 0335b58..7a8674a 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator.py
@@ -17,10 +17,12 @@
# under the License.
import json
+import logging
import os
import shutil
import sys
import unittest
+import textwrap
import kubernetes.client.models as k8s
import pendulum
@@ -834,6 +836,24 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
self.assertIsNotNone(result)
self.assertDictEqual(result, {"hello": "world"})
+ def test_pod_template_file_with_overrides_system(self):
+ fixture = sys.path[0] + '/tests/kubernetes/basic_pod.yaml'
+ k = KubernetesPodOperator(
+ task_id="task" + self.get_current_task_name(),
+ labels={"foo": "bar", "fizz": "buzz"},
+ env_vars={"env_name": "value"},
+ in_cluster=False,
+ pod_template_file=fixture,
+ do_xcom_push=True
+ )
+
+ context = create_context(k)
+ result = k.execute(context)
+ self.assertIsNotNone(result)
+ self.assertEqual(k.pod.metadata.labels, {'fizz': 'buzz', 'foo': 'bar'})
+ self.assertEqual(k.pod.spec.containers[0].env, [k8s.V1EnvVar(name="env_name", value="value")])
+ self.assertDictEqual(result, {"hello": "world"})
+
def test_init_container(self):
# GIVEN
volume_mounts = [k8s.V1VolumeMount(
@@ -917,48 +937,72 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
@mock.patch("airflow.kubernetes.kube_client.get_kube_client")
def test_pod_template_file(self, mock_client, monitor_mock, start_mock):
from airflow.utils.state import State
+ fixture = sys.path[0] + '/tests/kubernetes/pod.yaml'
k = KubernetesPodOperator(
task_id='task',
- pod_template_file='tests/kubernetes/pod.yaml',
+ pod_template_file=fixture,
do_xcom_push=True
)
monitor_mock.return_value = (State.SUCCESS, None)
- context = self.create_context(k)
- k.execute(context)
+ context = create_context(k)
+ with self.assertLogs(k.log, level=logging.DEBUG) as cm:
+ k.execute(context)
+ expected_line = textwrap.dedent("""\
+ DEBUG:airflow.task.operators:Starting pod:
+ api_version: v1
+ kind: Pod
+ metadata:
+ annotations: {}
+ cluster_name: null
+ creation_timestamp: null
+ deletion_grace_period_seconds: null\
+ """).strip()
+ self.assertTrue(any(line.startswith(expected_line) for line in cm.output))
+
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
- self.assertEqual({
- 'apiVersion': 'v1',
- 'kind': 'Pod',
- 'metadata': {'name': mock.ANY, 'namespace': 'mem-example'},
- 'spec': {
- 'volumes': [{'name': 'xcom', 'emptyDir': {}}],
- 'containers': [{
- 'args': ['--vm', '1', '--vm-bytes', '150M', '--vm-hang', '1'],
- 'command': ['stress'],
- 'image': 'apache/airflow:stress-2020.07.10-1.0.4',
- 'name': 'memory-demo-ctr',
- 'resources': {
- 'limits': {'memory': '200Mi'},
- 'requests': {'memory': '100Mi'}
- },
- 'volumeMounts': [{
- 'name': 'xcom',
- 'mountPath': '/airflow/xcom'
- }]
- }, {
- 'name': 'airflow-xcom-sidecar',
- 'image': "alpine",
- 'command': ['sh', '-c', PodDefaults.XCOM_CMD],
- 'volumeMounts': [
- {
- 'name': 'xcom',
- 'mountPath': '/airflow/xcom'
- }
- ],
- 'resources': {'requests': {'cpu': '1m'}},
- }],
- }
- }, actual_pod)
+ expected_dict = {'apiVersion': 'v1',
+ 'kind': 'Pod',
+ 'metadata': {'annotations': {},
+ 'labels': {},
+ 'name': 'memory-demo',
+ 'namespace': 'mem-example'},
+ 'spec': {'affinity': {},
+ 'containers': [{'args': ['--vm',
+ '1',
+ '--vm-bytes',
+ '150M',
+ '--vm-hang',
+ '1'],
+ 'command': ['stress'],
+ 'env': [],
+ 'envFrom': [],
+ 'image': 'apache/airflow:stress-2020.07.10-1.0.4',
+ 'imagePullPolicy': 'IfNotPresent',
+ 'name': 'base',
+ 'ports': [],
+ 'resources': {'limits': {'memory': '200Mi'},
+ 'requests': {'memory': '100Mi'}},
+ 'volumeMounts': [{'mountPath': '/airflow/xcom',
+ 'name': 'xcom'}]},
+ {'command': ['sh',
+ '-c',
+ 'trap "exit 0" INT; while true; do sleep '
+ '30; done;'],
+ 'image': 'alpine',
+ 'name': 'airflow-xcom-sidecar',
+ 'resources': {'requests': {'cpu': '1m'}},
+ 'volumeMounts': [{'mountPath': '/airflow/xcom',
+ 'name': 'xcom'}]}],
+ 'hostNetwork': False,
+ 'imagePullSecrets': [],
+ 'initContainers': [],
+ 'nodeSelector': {},
+ 'restartPolicy': 'Never',
+ 'securityContext': {},
+ 'serviceAccountName': 'default',
+ 'tolerations': [],
+ 'volumes': [{'emptyDir': {}, 'name': 'xcom'}]}}
+ self.assertEqual(expected_dict, actual_pod)
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
diff --git a/tests/kubernetes/test_pod_generator.py b/tests/kubernetes/test_pod_generator.py
index 5243673..0c9d722 100644
--- a/tests/kubernetes/test_pod_generator.py
+++ b/tests/kubernetes/test_pod_generator.py
@@ -16,12 +16,12 @@
# under the License.
import unittest
+import sys
from tests.compat import mock
import uuid
import kubernetes.client.models as k8s
from kubernetes.client import ApiClient
-from airflow.exceptions import AirflowConfigException
from airflow.kubernetes.k8s_model import append_to_pod
from airflow.kubernetes.pod import Resources
from airflow.kubernetes.pod_generator import PodDefaults, PodGenerator, extend_object_field, merge_objects
@@ -1045,7 +1045,7 @@ class TestPodGenerator(unittest.TestCase):
self.assertEqual(client_spec, res)
def test_deserialize_model_file(self):
- fixture = 'tests/kubernetes/pod.yaml'
+ fixture = sys.path[0] + '/tests/kubernetes/pod.yaml'
result = PodGenerator.deserialize_model_file(fixture)
sanitized_res = self.k8s_client.sanitize_for_serialization(result)
self.assertEqual(sanitized_res, self.deserialize_result)
@@ -1073,18 +1073,6 @@ spec:
sanitized_res = self.k8s_client.sanitize_for_serialization(result)
self.assertEqual(sanitized_res, self.deserialize_result)
- def test_validate_pod_generator(self):
- with self.assertRaises(AirflowConfigException):
- PodGenerator(image='k', pod=k8s.V1Pod())
- with self.assertRaises(AirflowConfigException):
- PodGenerator(pod=k8s.V1Pod(), pod_template_file='k')
- with self.assertRaises(AirflowConfigException):
- PodGenerator(image='k', pod_template_file='k')
-
- PodGenerator(image='k')
- PodGenerator(pod_template_file='tests/kubernetes/pod.yaml')
- PodGenerator(pod=k8s.V1Pod())
-
def test_add_custom_label(self):
from kubernetes.client import models as k8s