You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2020/10/03 19:48:19 UTC

[airflow] 06/14: Allow overrides for pod_template_file (#11162)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 00203dbbd1fc372d3770f0ec858d95b4330a0cfa
Author: Daniel Imberman <da...@gmail.com>
AuthorDate: Sun Sep 27 14:39:35 2020 -0700

    Allow overrides for pod_template_file (#11162)
    
    * Allow overrides for pod_template_file
    
    A pod_template_file should be treated as a *template*, not a steadfast
    rule.
    
    This PR ensures that users can override individual values set by the
    pod_template_file, so that the same file can be used for multiple tasks.
    
    * fix podtemplatetest
    
    * fix name
    
    (cherry picked from commit a888198c27bcdbc4538c02360c308ffcaca182fa)
---
 .../contrib/operators/kubernetes_pod_operator.py   |  33 ++++--
 airflow/kubernetes/pod_generator.py                |  48 ---------
 kubernetes_tests/test_kubernetes_pod_operator.py   | 116 ++++++++++++++-------
 tests/kubernetes/test_pod_generator.py             |  16 +--
 4 files changed, 108 insertions(+), 105 deletions(-)

diff --git a/airflow/contrib/operators/kubernetes_pod_operator.py b/airflow/contrib/operators/kubernetes_pod_operator.py
index cdf5076..7754fd7 100644
--- a/airflow/contrib/operators/kubernetes_pod_operator.py
+++ b/airflow/contrib/operators/kubernetes_pod_operator.py
@@ -20,14 +20,16 @@ import re
 import yaml
 
 from airflow.exceptions import AirflowException
-from airflow.kubernetes import kube_client, pod_generator, pod_launcher
 from airflow.kubernetes.k8s_model import append_to_pod
+from airflow.kubernetes import kube_client, pod_generator, pod_launcher
 from airflow.kubernetes.pod import Resources
 from airflow.models import BaseOperator
 from airflow.utils.decorators import apply_defaults
 from airflow.utils.helpers import validate_key
 from airflow.utils.state import State
 from airflow.version import version as airflow_version
+from airflow.kubernetes.pod_generator import PodGenerator
+from kubernetes.client import models as k8s
 
 
 class KubernetesPodOperator(BaseOperator):  # pylint: disable=too-many-instance-attributes
@@ -218,8 +220,9 @@ class KubernetesPodOperator(BaseOperator):  # pylint: disable=too-many-instance-
         self.annotations = annotations or {}
         self.affinity = affinity or {}
         self.resources = self._set_resources(resources)  # noqa
+        self.k8s_resources = self.resources
         self.config_file = config_file
-        self.image_pull_secrets = image_pull_secrets
+        self.image_pull_secrets = image_pull_secrets or []
         self.service_account_name = service_account_name
         self.is_delete_operator_pod = is_delete_operator_pod
         self.hostnetwork = hostnetwork
@@ -272,6 +275,9 @@ class KubernetesPodOperator(BaseOperator):  # pylint: disable=too-many-instance-
                 client = kube_client.get_kube_client(cluster_context=self.cluster_context,
                                                      config_file=self.config_file)
 
+            self.pod = self.create_pod_request_obj()
+            self.namespace = self.pod.metadata.namespace
+
             self.client = client
 
             # Add combination of labels to uniquely identify a running pod
@@ -356,6 +362,11 @@ class KubernetesPodOperator(BaseOperator):  # pylint: disable=too-many-instance-
         Creates a V1Pod based on user parameters. Note that a `pod` or `pod_template_file`
         will supersede all other values.
         """
+        if self.pod_template_file:
+            pod_template = pod_generator.PodGenerator.deserialize_model_file(self.pod_template_file)
+        else:
+            pod_template = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="name"))
+
         pod = pod_generator.PodGenerator(
             image=self.image,
             namespace=self.namespace,
@@ -373,15 +384,12 @@ class KubernetesPodOperator(BaseOperator):  # pylint: disable=too-many-instance-
             service_account_name=self.service_account_name,
             hostnetwork=self.hostnetwork,
             tolerations=self.tolerations,
-            configmaps=self.configmaps,
             security_context=self.security_context,
             dnspolicy=self.dnspolicy,
             init_containers=self.init_containers,
             restart_policy='Never',
             schedulername=self.schedulername,
-            pod_template_file=self.pod_template_file,
             priority_class_name=self.priority_class_name,
-            pod=self.full_pod_spec,
         ).gen_pod()
 
         # noinspection PyTypeChecker
@@ -395,6 +403,17 @@ class KubernetesPodOperator(BaseOperator):  # pylint: disable=too-many-instance-
             self.volume_mounts  # type: ignore
         )
 
+        env_from = pod.spec.containers[0].env_from or []
+        for configmap in self.configmaps:
+            env_from.append(k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name=configmap)))
+        pod.spec.containers[0].env_from = env_from
+
+        if self.full_pod_spec:
+            pod_template = PodGenerator.reconcile_pods(pod_template, self.full_pod_spec)
+        pod = PodGenerator.reconcile_pods(pod_template, pod)
+
+        # if self.do_xcom_push:
+        #     pod = PodGenerator.add_sidecar(pod)
         return pod
 
     def create_new_pod_for_operator(self, labels, launcher):
@@ -435,9 +454,9 @@ class KubernetesPodOperator(BaseOperator):  # pylint: disable=too-many-instance-
 
     def monitor_launched_pod(self, launcher, pod):
         """
-        Montitors a pod to completion that was created by a previous KubernetesPodOperator
+        Monitors a pod to completion that was created by a previous KubernetesPodOperator
 
-        @param launcher: pod launcher that will manage launching and monitoring pods
+        :param launcher: pod launcher that will manage launching and monitoring pods
         :param pod: podspec used to find pod using k8s API
         :return:
         """
diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py
index ed518d1..4fbfec1 100644
--- a/airflow/kubernetes/pod_generator.py
+++ b/airflow/kubernetes/pod_generator.py
@@ -24,11 +24,6 @@ is supported and no serialization need be written.
 import copy
 import hashlib
 import re
-try:
-    from inspect import signature
-except ImportError:
-    # Python 2.7
-    from funcsigs import signature  # type: ignore
 import os
 import uuid
 from functools import reduce
@@ -203,7 +198,6 @@ class PodGenerator(object):
         pod_template_file=None,
         extract_xcom=False,
     ):
-        self.validate_pod_generator_args(locals())
 
         if pod_template_file:
             self.ud_pod = self.deserialize_model_file(pod_template_file)
@@ -556,48 +550,6 @@ class PodGenerator(object):
         # pylint: disable=protected-access
         return api_client._ApiClient__deserialize_model(pod, k8s.V1Pod)
 
-    @staticmethod
-    def validate_pod_generator_args(given_args):
-        """
-        :param given_args: The arguments passed to the PodGenerator constructor.
-        :type given_args: dict
-        :return: None
-
-        Validate that if `pod` or `pod_template_file` are set that the user is not attempting
-        to configure the pod with the other arguments.
-        """
-        pod_args = list(signature(PodGenerator).parameters.items())
-
-        def predicate(k, v):
-            """
-            :param k: an arg to PodGenerator
-            :type k: string
-            :param v: the parameter of the given arg
-            :type v: inspect.Parameter
-            :return: bool
-
-            returns True if the PodGenerator argument has no default arguments
-            or the default argument is None, and it is not one of the listed field
-            in `non_empty_fields`.
-            """
-            non_empty_fields = {
-                'pod', 'pod_template_file', 'extract_xcom', 'service_account_name', 'image_pull_policy',
-                'restart_policy'
-            }
-
-            return (v.default is None or v.default is v.empty) and k not in non_empty_fields
-
-        args_without_defaults = {k: given_args[k] for k, v in pod_args if predicate(k, v) and given_args[k]}
-
-        if given_args['pod'] and given_args['pod_template_file']:
-            raise AirflowConfigException("Cannot pass both `pod` and `pod_template_file` arguments")
-        if args_without_defaults and (given_args['pod'] or given_args['pod_template_file']):
-            raise AirflowConfigException(
-                "Cannot configure pod and pass either `pod` or `pod_template_file`. Fields {} passed.".format(
-                    list(args_without_defaults.keys())
-                )
-            )
-
 
 def merge_objects(base_obj, client_obj):
     """
diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py
index 0335b58..7a8674a 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator.py
@@ -17,10 +17,12 @@
 # under the License.
 
 import json
+import logging
 import os
 import shutil
 import sys
 import unittest
+import textwrap
 
 import kubernetes.client.models as k8s
 import pendulum
@@ -834,6 +836,24 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
         self.assertIsNotNone(result)
         self.assertDictEqual(result, {"hello": "world"})
 
+    def test_pod_template_file_with_overrides_system(self):
+        fixture = sys.path[0] + '/tests/kubernetes/basic_pod.yaml'
+        k = KubernetesPodOperator(
+            task_id="task" + self.get_current_task_name(),
+            labels={"foo": "bar", "fizz": "buzz"},
+            env_vars={"env_name": "value"},
+            in_cluster=False,
+            pod_template_file=fixture,
+            do_xcom_push=True
+        )
+
+        context = create_context(k)
+        result = k.execute(context)
+        self.assertIsNotNone(result)
+        self.assertEqual(k.pod.metadata.labels, {'fizz': 'buzz', 'foo': 'bar'})
+        self.assertEqual(k.pod.spec.containers[0].env, [k8s.V1EnvVar(name="env_name", value="value")])
+        self.assertDictEqual(result, {"hello": "world"})
+
     def test_init_container(self):
         # GIVEN
         volume_mounts = [k8s.V1VolumeMount(
@@ -917,48 +937,72 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
     @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
     def test_pod_template_file(self, mock_client, monitor_mock, start_mock):
         from airflow.utils.state import State
+        fixture = sys.path[0] + '/tests/kubernetes/pod.yaml'
         k = KubernetesPodOperator(
             task_id='task',
-            pod_template_file='tests/kubernetes/pod.yaml',
+            pod_template_file=fixture,
             do_xcom_push=True
         )
         monitor_mock.return_value = (State.SUCCESS, None)
-        context = self.create_context(k)
-        k.execute(context)
+        context = create_context(k)
+        with self.assertLogs(k.log, level=logging.DEBUG) as cm:
+            k.execute(context)
+            expected_line = textwrap.dedent("""\
+            DEBUG:airflow.task.operators:Starting pod:
+            api_version: v1
+            kind: Pod
+            metadata:
+              annotations: {}
+              cluster_name: null
+              creation_timestamp: null
+              deletion_grace_period_seconds: null\
+            """).strip()
+            self.assertTrue(any(line.startswith(expected_line) for line in cm.output))
+
         actual_pod = self.api_client.sanitize_for_serialization(k.pod)
-        self.assertEqual({
-            'apiVersion': 'v1',
-            'kind': 'Pod',
-            'metadata': {'name': mock.ANY, 'namespace': 'mem-example'},
-            'spec': {
-                'volumes': [{'name': 'xcom', 'emptyDir': {}}],
-                'containers': [{
-                    'args': ['--vm', '1', '--vm-bytes', '150M', '--vm-hang', '1'],
-                    'command': ['stress'],
-                    'image': 'apache/airflow:stress-2020.07.10-1.0.4',
-                    'name': 'memory-demo-ctr',
-                    'resources': {
-                        'limits': {'memory': '200Mi'},
-                        'requests': {'memory': '100Mi'}
-                    },
-                    'volumeMounts': [{
-                        'name': 'xcom',
-                        'mountPath': '/airflow/xcom'
-                    }]
-                }, {
-                    'name': 'airflow-xcom-sidecar',
-                    'image': "alpine",
-                    'command': ['sh', '-c', PodDefaults.XCOM_CMD],
-                    'volumeMounts': [
-                        {
-                            'name': 'xcom',
-                            'mountPath': '/airflow/xcom'
-                        }
-                    ],
-                    'resources': {'requests': {'cpu': '1m'}},
-                }],
-            }
-        }, actual_pod)
+        expected_dict = {'apiVersion': 'v1',
+                         'kind': 'Pod',
+                         'metadata': {'annotations': {},
+                                      'labels': {},
+                                      'name': 'memory-demo',
+                                      'namespace': 'mem-example'},
+                         'spec': {'affinity': {},
+                                  'containers': [{'args': ['--vm',
+                                                           '1',
+                                                           '--vm-bytes',
+                                                           '150M',
+                                                           '--vm-hang',
+                                                           '1'],
+                                                  'command': ['stress'],
+                                                  'env': [],
+                                                  'envFrom': [],
+                                                  'image': 'apache/airflow:stress-2020.07.10-1.0.4',
+                                                  'imagePullPolicy': 'IfNotPresent',
+                                                  'name': 'base',
+                                                  'ports': [],
+                                                  'resources': {'limits': {'memory': '200Mi'},
+                                                                'requests': {'memory': '100Mi'}},
+                                                  'volumeMounts': [{'mountPath': '/airflow/xcom',
+                                                                    'name': 'xcom'}]},
+                                                 {'command': ['sh',
+                                                              '-c',
+                                                              'trap "exit 0" INT; while true; do sleep '
+                                                              '30; done;'],
+                                                  'image': 'alpine',
+                                                  'name': 'airflow-xcom-sidecar',
+                                                  'resources': {'requests': {'cpu': '1m'}},
+                                                  'volumeMounts': [{'mountPath': '/airflow/xcom',
+                                                                    'name': 'xcom'}]}],
+                                  'hostNetwork': False,
+                                  'imagePullSecrets': [],
+                                  'initContainers': [],
+                                  'nodeSelector': {},
+                                  'restartPolicy': 'Never',
+                                  'securityContext': {},
+                                  'serviceAccountName': 'default',
+                                  'tolerations': [],
+                                  'volumes': [{'emptyDir': {}, 'name': 'xcom'}]}}
+        self.assertEqual(expected_dict, actual_pod)
 
     @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
     @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
diff --git a/tests/kubernetes/test_pod_generator.py b/tests/kubernetes/test_pod_generator.py
index 5243673..0c9d722 100644
--- a/tests/kubernetes/test_pod_generator.py
+++ b/tests/kubernetes/test_pod_generator.py
@@ -16,12 +16,12 @@
 # under the License.
 
 import unittest
+import sys
 from tests.compat import mock
 import uuid
 import kubernetes.client.models as k8s
 from kubernetes.client import ApiClient
 
-from airflow.exceptions import AirflowConfigException
 from airflow.kubernetes.k8s_model import append_to_pod
 from airflow.kubernetes.pod import Resources
 from airflow.kubernetes.pod_generator import PodDefaults, PodGenerator, extend_object_field, merge_objects
@@ -1045,7 +1045,7 @@ class TestPodGenerator(unittest.TestCase):
         self.assertEqual(client_spec, res)
 
     def test_deserialize_model_file(self):
-        fixture = 'tests/kubernetes/pod.yaml'
+        fixture = sys.path[0] + '/tests/kubernetes/pod.yaml'
         result = PodGenerator.deserialize_model_file(fixture)
         sanitized_res = self.k8s_client.sanitize_for_serialization(result)
         self.assertEqual(sanitized_res, self.deserialize_result)
@@ -1073,18 +1073,6 @@ spec:
         sanitized_res = self.k8s_client.sanitize_for_serialization(result)
         self.assertEqual(sanitized_res, self.deserialize_result)
 
-    def test_validate_pod_generator(self):
-        with self.assertRaises(AirflowConfigException):
-            PodGenerator(image='k', pod=k8s.V1Pod())
-        with self.assertRaises(AirflowConfigException):
-            PodGenerator(pod=k8s.V1Pod(), pod_template_file='k')
-        with self.assertRaises(AirflowConfigException):
-            PodGenerator(image='k', pod_template_file='k')
-
-        PodGenerator(image='k')
-        PodGenerator(pod_template_file='tests/kubernetes/pod.yaml')
-        PodGenerator(pod=k8s.V1Pod())
-
     def test_add_custom_label(self):
         from kubernetes.client import models as k8s