Posted to commits@airflow.apache.org by po...@apache.org on 2020/06/29 13:21:37 UTC

[airflow] 31/37: [AIRFLOW-5413] Allow K8S worker pod to be configured from JSON/YAML file (#6230)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit b222e1717b8dd466bfb6880e0e079bc19f60e383
Author: Daniel Imberman <da...@astronomer.io>
AuthorDate: Fri Jun 26 17:55:00 2020 -0700

    [AIRFLOW-5413] Allow K8S worker pod to be configured from JSON/YAML file (#6230)
    
    * [AIRFLOW-5413] enable pod config from file
    
    * Update airflow/kubernetes/pod_generator.py
    
    Co-Authored-By: Ash Berlin-Taylor <as...@firemirror.com>
    
    * Update airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
    
    Co-Authored-By: Ash Berlin-Taylor <as...@firemirror.com>
    
    Co-authored-by: Ash Berlin-Taylor <as...@firemirror.com>
    (cherry picked from commit 967930c0cb6e2293f2a49e5c9add5aa1917f3527)
---
 airflow/config_templates/config.yml                |   7 +
 airflow/config_templates/default_airflow.cfg       |   3 +
 .../contrib/operators/kubernetes_pod_operator.py   |  15 ++-
 .../example_kubernetes_executor_config.py          |   6 +-
 airflow/executors/kubernetes_executor.py           |  11 +-
 airflow/kubernetes/pod_generator.py                | 144 ++++++++++++++-------
 airflow/kubernetes/worker_configuration.py         |  20 ++-
 kubernetes_tests/test_kubernetes_pod_operator.py   |  50 ++++++-
 tests/executors/test_kubernetes_executor.py        |   2 +-
 tests/kubernetes/models/test_pod.py                |   2 -
 tests/kubernetes/models/test_secret.py             |   2 -
 tests/kubernetes/pod.yaml                          |  33 +++++
 tests/kubernetes/test_pod_generator.py             |  80 +++++++++---
 tests/kubernetes/test_worker_configuration.py      |  14 +-
 14 files changed, 304 insertions(+), 85 deletions(-)
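
In practice, the new `pod_template_file` argument lets an operator pull its entire pod spec from a
YAML file. A minimal sketch against the v1-10 import path; the DAG id, task id and template path
are illustrative and not part of this commit:

    from airflow import DAG
    from airflow.contrib.operators.kubernetes_pod_operator import KubernetesPodOperator
    from airflow.utils.dates import days_ago

    # The pod spec (image, command, resources, namespace) comes entirely from the
    # YAML template, so none of it has to be repeated on the operator.
    with DAG(dag_id="pod_template_example",
             start_date=days_ago(1),
             schedule_interval=None) as dag:
        run_from_template = KubernetesPodOperator(
            task_id="run_from_template",
            pod_template_file="/opt/airflow/pod_templates/worker_pod.yaml",
            do_xcom_push=True,
        )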

diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index 9b63200..61491d8 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -1736,6 +1736,13 @@
       type: string
       example: ~
       default: ""
+    - name: pod_template_file
+      description: |
+        Path to the YAML pod file. If set, all other kubernetes-related fields are ignored.
+      version_added: ~
+      type: string
+      example: ~
+      default: ""
     - name: worker_container_tag
       description: ~
       version_added: ~
diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg
index ca9de12..2cc97e2 100644
--- a/airflow/config_templates/default_airflow.cfg
+++ b/airflow/config_templates/default_airflow.cfg
@@ -801,6 +801,9 @@ verify_certs = True
 [kubernetes]
 # The repository, tag and imagePullPolicy of the Kubernetes Image for the Worker to Run
 worker_container_repository =
+
+# Path to the YAML pod file. If set, all other kubernetes-related fields are ignored.
+pod_template_file =
 worker_container_tag =
 worker_container_image_pull_policy = IfNotPresent
 
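With the KubernetesExecutor, the option is read from the [kubernetes] section via conf.get(), so it
can be set in airflow.cfg or through Airflow's usual AIRFLOW__KUBERNETES__POD_TEMPLATE_FILE
environment variable. A sketch of a populated entry, with an illustrative path:

    [kubernetes]
    # Path to the YAML pod file. If set, all other kubernetes-related fields are ignored.
    pod_template_file = /opt/airflow/pod_templates/worker_pod.yaml
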
diff --git a/airflow/contrib/operators/kubernetes_pod_operator.py b/airflow/contrib/operators/kubernetes_pod_operator.py
index ce8f19c..41f0df3 100644
--- a/airflow/contrib/operators/kubernetes_pod_operator.py
+++ b/airflow/contrib/operators/kubernetes_pod_operator.py
@@ -130,8 +130,18 @@ class KubernetesPodOperator(BaseOperator):  # pylint: disable=too-many-instance-
     :type schedulername: str
     :param full_pod_spec: The complete podSpec
     :type full_pod_spec: kubernetes.client.models.V1Pod
+    :param init_containers: init container for the launched Pod
+    :type init_containers: list[kubernetes.client.models.V1Container]
+    :param log_events_on_failure: Log the pod's events if a failure occurs
+    :type log_events_on_failure: bool
+    :param do_xcom_push: If True, the content of the file
+        /airflow/xcom/return.json in the container will also be pushed to an
+        XCom when the container completes.
+    :type do_xcom_push: bool
+    :param pod_template_file: path to pod template file
+    :type pod_template_file: str
     """
-    template_fields = ('cmds', 'arguments', 'env_vars', 'config_file')
+    template_fields = ('cmds', 'arguments', 'env_vars', 'config_file', 'pod_template_file')
 
     @apply_defaults
     def __init__(self,  # pylint: disable=too-many-arguments,too-many-locals
@@ -215,8 +225,8 @@ class KubernetesPodOperator(BaseOperator):  # pylint: disable=too-many-instance-
         self.full_pod_spec = full_pod_spec
         self.init_containers = init_containers or []
         self.log_events_on_failure = log_events_on_failure
-        self.priority_class_name = priority_class_name
         self.pod_template_file = pod_template_file
+        self.priority_class_name = priority_class_name
         self.name = self._set_name(name)
 
     @staticmethod
@@ -348,6 +358,7 @@ class KubernetesPodOperator(BaseOperator):  # pylint: disable=too-many-instance-
             init_containers=self.init_containers,
             restart_policy='Never',
             schedulername=self.schedulername,
+            pod_template_file=self.pod_template_file,
             priority_class_name=self.priority_class_name,
             pod=self.full_pod_spec,
         ).gen_pod()
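
Because `pod_template_file` joins `template_fields` above, the path itself is rendered through
Jinja at task run time. A hypothetical fragment, extending the earlier sketch, that picks a
template per execution date:

    run_daily_pod = KubernetesPodOperator(
        task_id="run_daily_pod",
        # Rendered per run because 'pod_template_file' is now a template field.
        pod_template_file="/opt/airflow/pod_templates/{{ ds_nodash }}.yaml",
    )
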
diff --git a/airflow/example_dags/example_kubernetes_executor_config.py b/airflow/example_dags/example_kubernetes_executor_config.py
index d740956..2e4ba00 100644
--- a/airflow/example_dags/example_kubernetes_executor_config.py
+++ b/airflow/example_dags/example_kubernetes_executor_config.py
@@ -83,14 +83,14 @@ with DAG(
         }
     )
 
-    # Test that we can run tasks as a normal user
+    # Test that we can add labels to pods
     third_task = PythonOperator(
         task_id="non_root_task",
         python_callable=print_stuff,
         executor_config={
             "KubernetesExecutor": {
-                "securityContext": {
-                    "runAsUser": 1000
+                "labels": {
+                    "release": "stable"
                 }
             }
         }
diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py
index 98e3154..e014aa3 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -20,7 +20,6 @@ import json
 import multiprocessing
 import time
 from queue import Empty
-from uuid import uuid4
 
 import kubernetes
 from dateutil import parser
@@ -72,6 +71,9 @@ class KubeConfig:
         )
         self.kube_node_selectors = configuration_dict.get('kubernetes_node_selectors', {})
         self.kube_annotations = configuration_dict.get('kubernetes_annotations', {}) or None
+        self.pod_template_file = conf.get(self.kubernetes_section, 'pod_template_file',
+                                          fallback=None)
+
         self.kube_labels = configuration_dict.get('kubernetes_labels', {})
         self.delete_worker_pods = conf.getboolean(
             self.kubernetes_section, 'delete_worker_pods')
@@ -220,6 +222,8 @@ class KubeConfig:
             return int(val)
 
     def _validate(self):
+        if self.pod_template_file:
+            return
         # TODO: use XOR for dags_volume_claim and git_dags_folder_mount_point
         if not self.dags_volume_claim \
             and not self.dags_volume_host \
@@ -498,10 +502,7 @@ class AirflowKubernetesScheduler(LoggingMixin):
             dag_id)
         safe_task_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
             task_id)
-        safe_uuid = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
-            uuid4().hex)
-        return AirflowKubernetesScheduler._make_safe_pod_id(safe_dag_id, safe_task_id,
-                                                            safe_uuid)
+        return safe_dag_id + safe_task_id
 
     @staticmethod
     def _label_safe_datestring_to_datetime(string):
diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py
index 711b1a9..e46407b 100644
--- a/airflow/kubernetes/pod_generator.py
+++ b/airflow/kubernetes/pod_generator.py
@@ -24,10 +24,20 @@ is supported and no serialization need be written.
 import copy
 import hashlib
 import re
+try:
+    from inspect import signature
+except ImportError:
+    # Python 2.7
+    from funcsigs import signature  # type: ignore
+import os
 import uuid
+from functools import reduce
 
 import kubernetes.client.models as k8s
+import yaml
+from kubernetes.client.api_client import ApiClient
 
+from airflow.exceptions import AirflowConfigException
 from airflow.version import version as airflow_version
 
 MAX_LABEL_LEN = 63
@@ -35,10 +45,14 @@ MAX_LABEL_LEN = 63
 MAX_POD_ID_LEN = 253
 
 
-class PodDefaults:
+class PodDefaults(object):
     """
     Static defaults for the PodGenerator
     """
+
+    def __init__(self):
+        pass
+
     XCOM_MOUNT_PATH = '/airflow/xcom'
     SIDECAR_CONTAINER_NAME = 'airflow-xcom-sidecar'
     XCOM_CMD = 'trap "exit 0" INT; while true; do sleep 30; done;'
@@ -82,7 +96,7 @@ def make_safe_label_value(string):
     return safe_label
 
 
-class PodGenerator:
+class PodGenerator(object):
     """
     Contains Kubernetes Airflow Worker configuration logic
 
@@ -147,9 +161,11 @@ class PodGenerator:
     :param dnspolicy: Specify a dnspolicy for the pod
     :type dnspolicy: str
     :param schedulername: Specify a schedulername for the pod
-    :type schedulername: str
-    :param pod: The fully specified pod.
-    :type pod: kubernetes.client.models.V1Pod
+    :type schedulername: Optional[str]
+    :param pod: The fully specified pod. Mutually exclusive with `pod_template_file`
+    :type pod: Optional[kubernetes.client.models.V1Pod]
+    :param pod_template_file: Path to YAML file. Mutually exclusive with `pod`
+    :type pod_template_file: Optional[str]
     :param extract_xcom: Whether to bring up a container for xcom
     :type extract_xcom: bool
     """
@@ -167,8 +183,8 @@ class PodGenerator:
         node_selectors=None,
         ports=None,
         volumes=None,
-        image_pull_policy='IfNotPresent',
-        restart_policy='Never',
+        image_pull_policy=None,
+        restart_policy=None,
         image_pull_secrets=None,
         init_containers=None,
         service_account_name=None,
@@ -183,9 +199,16 @@ class PodGenerator:
         schedulername=None,
         priority_class_name=None,
         pod=None,
+        pod_template_file=None,
         extract_xcom=False,
     ):
-        self.ud_pod = pod
+        self.validate_pod_generator_args(locals())
+
+        if pod_template_file:
+            self.ud_pod = self.deserialize_model_file(pod_template_file)
+        else:
+            self.ud_pod = pod
+
         self.pod = k8s.V1Pod()
         self.pod.api_version = 'v1'
         self.pod.kind = 'Pod'
@@ -348,37 +371,7 @@ class PodGenerator:
                 'iam.cloud.google.com/service-account': gcp_service_account_key
             })
 
-        pod_spec_generator = PodGenerator(
-            image=namespaced.get('image'),
-            envs=namespaced.get('env'),
-            cmds=namespaced.get('cmds'),
-            args=namespaced.get('args'),
-            labels=namespaced.get('labels'),
-            node_selectors=namespaced.get('node_selectors'),
-            name=namespaced.get('name'),
-            ports=namespaced.get('ports'),
-            volumes=namespaced.get('volumes'),
-            volume_mounts=namespaced.get('volume_mounts'),
-            namespace=namespaced.get('namespace'),
-            image_pull_policy=namespaced.get('image_pull_policy'),
-            restart_policy=namespaced.get('restart_policy'),
-            image_pull_secrets=namespaced.get('image_pull_secrets'),
-            init_containers=namespaced.get('init_containers'),
-            service_account_name=namespaced.get('service_account_name'),
-            resources=resources,
-            annotations=namespaced.get('annotations'),
-            affinity=namespaced.get('affinity'),
-            hostnetwork=namespaced.get('hostnetwork'),
-            tolerations=namespaced.get('tolerations'),
-            security_context=namespaced.get('security_context'),
-            configmaps=namespaced.get('configmaps'),
-            dnspolicy=namespaced.get('dnspolicy'),
-            schedulername=namespaced.get('schedulername'),
-            pod=namespaced.get('pod'),
-            extract_xcom=namespaced.get('extract_xcom'),
-        )
-
-        return pod_spec_generator.gen_pod()
+        return PodGenerator(**namespaced).gen_pod()
 
     @staticmethod
     def reconcile_pods(base_pod, client_pod):
@@ -495,12 +488,73 @@ class PodGenerator:
             name=pod_id
         ).gen_pod()
 
-        # Reconcile the pod generated by the Operator and the Pod
-        # generated by the .cfg file
-        pod_with_executor_config = PodGenerator.reconcile_pods(worker_config,
-                                                               kube_executor_config)
-        # Reconcile that pod with the dynamic fields.
-        return PodGenerator.reconcile_pods(pod_with_executor_config, dynamic_pod)
+        # Reconcile the pods starting with the first chronologically:
+        # Pod from the airflow.cfg -> Pod from executor_config arg -> Pod from the K8s executor
+        pod_list = [worker_config, kube_executor_config, dynamic_pod]
+
+        return reduce(PodGenerator.reconcile_pods, pod_list)
+
+    @staticmethod
+    def deserialize_model_file(path):
+        """
+        :param path: Path to the file
+        :return: a kubernetes.client.models.V1Pod
+
+        Unfortunately we need access to the private method
+        ``_ApiClient__deserialize_model`` from the kubernetes client.
+        This issue is tracked here: https://github.com/kubernetes-client/python/issues/977.
+        """
+        api_client = ApiClient()
+        if os.path.exists(path):
+            with open(path) as stream:
+                pod = yaml.safe_load(stream)
+        else:
+            pod = yaml.safe_load(path)
+
+        # pylint: disable=protected-access
+        return api_client._ApiClient__deserialize_model(pod, k8s.V1Pod)
+
+    @staticmethod
+    def validate_pod_generator_args(given_args):
+        """
+        :param given_args: The arguments passed to the PodGenerator constructor.
+        :type given_args: dict
+        :return: None
+
+        Validate that if `pod` or `pod_template_file` are set that the user is not attempting
+        to configure the pod with the other arguments.
+        """
+        pod_args = list(signature(PodGenerator).parameters.items())
+
+        def predicate(k, v):
+            """
+            :param k: an arg to PodGenerator
+            :type k: string
+            :param v: the parameter of the given arg
+            :type v: inspect.Parameter
+            :return: bool
+
+            Returns True if the PodGenerator argument has no default value,
+            or its default value is None, and it is not one of the fields
+            listed in `non_empty_fields`.
+            """
+            non_empty_fields = {
+                'pod', 'pod_template_file', 'extract_xcom', 'service_account_name', 'image_pull_policy',
+                'restart_policy'
+            }
+
+            return (v.default is None or v.default is v.empty) and k not in non_empty_fields
+
+        args_without_defaults = {k: given_args[k] for k, v in pod_args if predicate(k, v) and given_args[k]}
+
+        if given_args['pod'] and given_args['pod_template_file']:
+            raise AirflowConfigException("Cannot pass both `pod` and `pod_template_file` arguments")
+        if args_without_defaults and (given_args['pod'] or given_args['pod_template_file']):
+            raise AirflowConfigException(
+                "Cannot configure pod and pass either `pod` or `pod_template_file`. Fields {} passed.".format(
+                    list(args_without_defaults.keys())
+                )
+            )
 
 
 def merge_objects(base_obj, client_obj):
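
For reference, a sketch of the two new PodGenerator entry points added in this file; the template
path is illustrative and assumed to exist on disk:

    from airflow.exceptions import AirflowConfigException
    from airflow.kubernetes.pod_generator import PodGenerator

    # deserialize_model_file() accepts either a file path or a raw YAML string
    # and returns a kubernetes.client.models.V1Pod.
    pod = PodGenerator.deserialize_model_file("/opt/airflow/pod_templates/worker_pod.yaml")

    # The generator can also consume the file directly; gen_pod() still applies
    # its usual post-processing (e.g. making the pod name unique).
    worker_pod = PodGenerator(pod_template_file="/opt/airflow/pod_templates/worker_pod.yaml").gen_pod()

    # validate_pod_generator_args() rejects mixing a full pod source with
    # individual pod fields.
    try:
        PodGenerator(image="alpine", pod_template_file="/opt/airflow/pod_templates/worker_pod.yaml")
    except AirflowConfigException as exc:
        print(exc)
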
diff --git a/airflow/kubernetes/worker_configuration.py b/airflow/kubernetes/worker_configuration.py
index 9c35910..0357f9f 100644
--- a/airflow/kubernetes/worker_configuration.py
+++ b/airflow/kubernetes/worker_configuration.py
@@ -28,7 +28,12 @@ from airflow.utils.log.logging_mixin import LoggingMixin
 
 
 class WorkerConfiguration(LoggingMixin):
-    """Contains Kubernetes Airflow Worker configuration logic"""
+    """
+    Contains Kubernetes Airflow Worker configuration logic
+
+    :param kube_config: the kubernetes configuration from airflow.cfg
+    :type kube_config: airflow.executors.kubernetes_executor.KubeConfig
+    """
 
     dags_volume_name = 'airflow-dags'
     logs_volume_name = 'airflow-logs'
@@ -424,9 +429,12 @@ class WorkerConfiguration(LoggingMixin):
 
     def as_pod(self):
         """Creates POD."""
-        pod_generator = PodGenerator(
+        if self.kube_config.pod_template_file:
+            return PodGenerator(pod_template_file=self.kube_config.pod_template_file).gen_pod()
+
+        pod = PodGenerator(
             image=self.kube_config.kube_image,
-            image_pull_policy=self.kube_config.kube_image_pull_policy,
+            image_pull_policy=self.kube_config.kube_image_pull_policy or 'IfNotPresent',
             image_pull_secrets=self.kube_config.image_pull_secrets,
             volumes=self._get_volumes(),
             volume_mounts=self._get_volume_mounts(),
@@ -436,10 +444,10 @@ class WorkerConfiguration(LoggingMixin):
             tolerations=self.kube_config.kube_tolerations,
             envs=self._get_environment(),
             node_selectors=self.kube_config.kube_node_selectors,
-            service_account_name=self.kube_config.worker_service_account_name,
-        )
+            service_account_name=self.kube_config.worker_service_account_name or 'default',
+            restart_policy='Never'
+        ).gen_pod()
 
-        pod = pod_generator.gen_pod()
         pod.spec.containers[0].env_from = pod.spec.containers[0].env_from or []
         pod.spec.containers[0].env_from.extend(self._get_env_from())
         pod.spec.security_context = self._get_security_context()
diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py
index e20324b..b6cecda 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator.py
@@ -827,7 +827,55 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
 
     @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
     @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
-    @patch("airflow.kubernetes.kube_client.get_kube_client")
+    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
+    def test_pod_template_file(self, mock_client, monitor_mock, start_mock):
+        from airflow.utils.state import State
+        k = KubernetesPodOperator(
+            task_id='task',
+            pod_template_file='tests/kubernetes/pod.yaml',
+            do_xcom_push=True
+        )
+        monitor_mock.return_value = (State.SUCCESS, None)
+        context = self.create_context(k)
+        k.execute(context)
+        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
+        self.assertEqual({
+            'apiVersion': 'v1',
+            'kind': 'Pod',
+            'metadata': {'name': mock.ANY, 'namespace': 'mem-example'},
+            'spec': {
+                'volumes': [{'name': 'xcom', 'emptyDir': {}}],
+                'containers': [{
+                    'args': ['--vm', '1', '--vm-bytes', '150M', '--vm-hang', '1'],
+                    'command': ['stress'],
+                    'image': 'polinux/stress',
+                    'name': 'memory-demo-ctr',
+                    'resources': {
+                        'limits': {'memory': '200Mi'},
+                        'requests': {'memory': '100Mi'}
+                    },
+                    'volumeMounts': [{
+                        'name': 'xcom',
+                        'mountPath': '/airflow/xcom'
+                    }]
+                }, {
+                    'name': 'airflow-xcom-sidecar',
+                    'image': "alpine",
+                    'command': ['sh', '-c', PodDefaults.XCOM_CMD],
+                    'volumeMounts': [
+                        {
+                            'name': 'xcom',
+                            'mountPath': '/airflow/xcom'
+                        }
+                    ],
+                    'resources': {'requests': {'cpu': '1m'}},
+                }],
+            }
+        }, actual_pod)
+
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
+    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
     def test_pod_priority_class_name(
             self,
             mock_client,
diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py
index bf7bc56..3dabb78 100644
--- a/tests/executors/test_kubernetes_executor.py
+++ b/tests/executors/test_kubernetes_executor.py
@@ -33,8 +33,8 @@ try:
     from airflow.configuration import conf  # noqa: F401
     from airflow.executors.kubernetes_executor import AirflowKubernetesScheduler, KubeConfig
     from airflow.executors.kubernetes_executor import KubernetesExecutor
-    from airflow.kubernetes import pod_generator
     from airflow.kubernetes.pod_generator import PodGenerator
+    from airflow.kubernetes import pod_generator
     from airflow.utils.state import State
 except ImportError:
     AirflowKubernetesScheduler = None  # type: ignore
diff --git a/tests/kubernetes/models/test_pod.py b/tests/kubernetes/models/test_pod.py
index b63af6d..45c32aa 100644
--- a/tests/kubernetes/models/test_pod.py
+++ b/tests/kubernetes/models/test_pod.py
@@ -59,7 +59,6 @@ class TestPod(unittest.TestCase):
                     'env': [],
                     'envFrom': [],
                     'image': 'airflow-worker:latest',
-                    'imagePullPolicy': 'IfNotPresent',
                     'name': 'base',
                     'ports': [{
                         'name': 'https',
@@ -72,7 +71,6 @@ class TestPod(unittest.TestCase):
                 }],
                 'hostNetwork': False,
                 'imagePullSecrets': [],
-                'restartPolicy': 'Never',
                 'volumes': []
             }
         }, result)
diff --git a/tests/kubernetes/models/test_secret.py b/tests/kubernetes/models/test_secret.py
index 44ab8b3..e91ff68 100644
--- a/tests/kubernetes/models/test_secret.py
+++ b/tests/kubernetes/models/test_secret.py
@@ -100,7 +100,6 @@ class TestSecret(unittest.TestCase):
                     }],
                     'envFrom': [{'secretRef': {'name': 'secret_a'}}],
                     'image': 'airflow-worker:latest',
-                    'imagePullPolicy': 'IfNotPresent',
                     'name': 'base',
                     'ports': [],
                     'volumeMounts': [{
@@ -110,7 +109,6 @@ class TestSecret(unittest.TestCase):
                 }],
                 'hostNetwork': False,
                 'imagePullSecrets': [],
-                'restartPolicy': 'Never',
                 'volumes': [{
                     'name': 'secretvol' + str(static_uuid),
                     'secret': {'secretName': 'secret_b'}
diff --git a/tests/kubernetes/pod.yaml b/tests/kubernetes/pod.yaml
new file mode 100644
index 0000000..cd419ed
--- /dev/null
+++ b/tests/kubernetes/pod.yaml
@@ -0,0 +1,33 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+---
+apiVersion: v1
+kind: Pod
+metadata:
+  name: memory-demo
+  namespace: mem-example
+spec:
+  containers:
+    - name: memory-demo-ctr
+      image: polinux/stress
+      resources:
+        limits:
+          memory: "200Mi"
+        requests:
+          memory: "100Mi"
+      command: ["stress"]
+      args: ["--vm", "1", "--vm-bytes", "150M", "--vm-hang", "1"]
diff --git a/tests/kubernetes/test_pod_generator.py b/tests/kubernetes/test_pod_generator.py
index ce15a8b..7d39cdc 100644
--- a/tests/kubernetes/test_pod_generator.py
+++ b/tests/kubernetes/test_pod_generator.py
@@ -21,6 +21,7 @@ import uuid
 import kubernetes.client.models as k8s
 from kubernetes.client import ApiClient
 
+from airflow.exceptions import AirflowConfigException
 from airflow.kubernetes.k8s_model import append_to_pod
 from airflow.kubernetes.pod import Resources
 from airflow.kubernetes.pod_generator import PodDefaults, PodGenerator, extend_object_field, merge_objects
@@ -31,6 +32,24 @@ class TestPodGenerator(unittest.TestCase):
 
     def setUp(self):
         self.static_uuid = uuid.UUID('cf4a56d2-8101-4217-b027-2af6216feb48')
+        self.deserialize_result = {
+            'apiVersion': 'v1',
+            'kind': 'Pod',
+            'metadata': {'name': 'memory-demo', 'namespace': 'mem-example'},
+            'spec': {
+                'containers': [{
+                    'args': ['--vm', '1', '--vm-bytes', '150M', '--vm-hang', '1'],
+                    'command': ['stress'],
+                    'image': 'polinux/stress',
+                    'name': 'memory-demo-ctr',
+                    'resources': {
+                        'limits': {'memory': '200Mi'},
+                        'requests': {'memory': '100Mi'}
+                    }
+                }]
+            }
+        }
+
         self.envs = {
             'ENVIRONMENT': 'prod',
             'LOG_LEVEL': 'warning'
@@ -77,7 +96,6 @@ class TestPodGenerator(unittest.TestCase):
                     'command': [
                         'sh', '-c', 'echo Hello Kubernetes!'
                     ],
-                    'imagePullPolicy': 'IfNotPresent',
                     'env': [{
                         'name': 'ENVIRONMENT',
                         'value': 'prod'
@@ -126,7 +144,6 @@ class TestPodGenerator(unittest.TestCase):
                         'readOnly': True
                     }]
                 }],
-                'restartPolicy': 'Never',
                 'volumes': [{
                     'name': 'secretvol' + str(self.static_uuid),
                     'secret': {
@@ -172,7 +189,7 @@ class TestPodGenerator(unittest.TestCase):
         result_dict['spec']['containers'][0]['envFrom'].sort(
             key=lambda x: list(x.values())[0]['name']
         )
-        self.assertDictEqual(result_dict, self.expected)
+        self.assertDictEqual(self.expected, result_dict)
 
     @mock.patch('uuid.uuid4')
     def test_gen_pod_extract_xcom(self, mock_uuid):
@@ -238,9 +255,6 @@ class TestPodGenerator(unittest.TestCase):
                         "name": "example-kubernetes-test-volume",
                     },
                 ],
-                "securityContext": {
-                    "runAsUser": 1000
-                }
             }
         })
         result = self.k8s_client.sanitize_for_serialization(result)
@@ -264,6 +278,7 @@ class TestPodGenerator(unittest.TestCase):
                         'name': 'example-kubernetes-test-volume'
                     }],
                 }],
+                'hostNetwork': False,
                 'imagePullSecrets': [],
                 'volumes': [{
                     'hostPath': {'path': '/tmp/'},
@@ -339,7 +354,7 @@ class TestPodGenerator(unittest.TestCase):
 
         result = PodGenerator.reconcile_pods(base_pod, mutator_pod)
         result = self.k8s_client.sanitize_for_serialization(result)
-        self.assertEqual(result, {
+        self.assertEqual({
             'apiVersion': 'v1',
             'kind': 'Pod',
             'metadata': {'name': 'name2-' + self.static_uuid.hex},
@@ -353,7 +368,6 @@ class TestPodGenerator(unittest.TestCase):
                     ],
                     'envFrom': [],
                     'image': 'image1',
-                    'imagePullPolicy': 'IfNotPresent',
                     'name': 'base',
                     'ports': [{
                         'containerPort': 2118,
@@ -369,7 +383,6 @@ class TestPodGenerator(unittest.TestCase):
                 }],
                 'hostNetwork': False,
                 'imagePullSecrets': [],
-                'restartPolicy': 'Never',
                 'volumes': [{
                     'hostPath': {'path': '/tmp/'},
                     'name': 'example-kubernetes-test-volume1'
@@ -378,7 +391,7 @@ class TestPodGenerator(unittest.TestCase):
                     'name': 'example-kubernetes-test-volume2'
                 }]
             }
-        })
+        }, result)
 
     @mock.patch('uuid.uuid4')
     def test_construct_pod_empty_worker_config(self, mock_uuid):
@@ -424,7 +437,6 @@ class TestPodGenerator(unittest.TestCase):
                     'command': ['command'],
                     'env': [],
                     'envFrom': [],
-                    'imagePullPolicy': 'IfNotPresent',
                     'name': 'base',
                     'ports': [],
                     'resources': {
@@ -437,7 +449,6 @@ class TestPodGenerator(unittest.TestCase):
                 }],
                 'hostNetwork': False,
                 'imagePullSecrets': [],
-                'restartPolicy': 'Never',
                 'volumes': []
             }
         }, sanitized_result)
@@ -486,7 +497,6 @@ class TestPodGenerator(unittest.TestCase):
                     'command': ['command'],
                     'env': [],
                     'envFrom': [],
-                    'imagePullPolicy': 'IfNotPresent',
                     'name': 'base',
                     'ports': [],
                     'resources': {
@@ -499,7 +509,6 @@ class TestPodGenerator(unittest.TestCase):
                 }],
                 'hostNetwork': False,
                 'imagePullSecrets': [],
-                'restartPolicy': 'Never',
                 'volumes': []
             }
         }, sanitized_result)
@@ -573,7 +582,6 @@ class TestPodGenerator(unittest.TestCase):
                     'command': ['command'],
                     'env': [],
                     'envFrom': [],
-                    'imagePullPolicy': 'IfNotPresent',
                     'name': 'base',
                     'ports': [],
                     'resources': {
@@ -587,7 +595,6 @@ class TestPodGenerator(unittest.TestCase):
                 }],
                 'hostNetwork': False,
                 'imagePullSecrets': [],
-                'restartPolicy': 'Never',
                 'volumes': []
             }
         }, sanitized_result)
@@ -721,3 +728,44 @@ class TestPodGenerator(unittest.TestCase):
         client_spec.containers = [k8s.V1Container(name='client_container1', image='base_image')]
         client_spec.active_deadline_seconds = 100
         self.assertEqual(client_spec, res)
+
+    def test_deserialize_model_file(self):
+        fixture = 'tests/kubernetes/pod.yaml'
+        result = PodGenerator.deserialize_model_file(fixture)
+        sanitized_res = self.k8s_client.sanitize_for_serialization(result)
+        self.assertEqual(sanitized_res, self.deserialize_result)
+
+    def test_deserialize_model_string(self):
+        fixture = """
+apiVersion: v1
+kind: Pod
+metadata:
+  name: memory-demo
+  namespace: mem-example
+spec:
+  containers:
+    - name: memory-demo-ctr
+      image: polinux/stress
+      resources:
+        limits:
+          memory: "200Mi"
+        requests:
+          memory: "100Mi"
+      command: ["stress"]
+      args: ["--vm", "1", "--vm-bytes", "150M", "--vm-hang", "1"]
+        """
+        result = PodGenerator.deserialize_model_file(fixture)
+        sanitized_res = self.k8s_client.sanitize_for_serialization(result)
+        self.assertEqual(sanitized_res, self.deserialize_result)
+
+    def test_validate_pod_generator(self):
+        with self.assertRaises(AirflowConfigException):
+            PodGenerator(image='k', pod=k8s.V1Pod())
+        with self.assertRaises(AirflowConfigException):
+            PodGenerator(pod=k8s.V1Pod(), pod_template_file='k')
+        with self.assertRaises(AirflowConfigException):
+            PodGenerator(image='k', pod_template_file='k')
+
+        PodGenerator(image='k')
+        PodGenerator(pod_template_file='tests/kubernetes/pod.yaml')
+        PodGenerator(pod=k8s.V1Pod())
diff --git a/tests/kubernetes/test_worker_configuration.py b/tests/kubernetes/test_worker_configuration.py
index 0730595..b6db2b5 100644
--- a/tests/kubernetes/test_worker_configuration.py
+++ b/tests/kubernetes/test_worker_configuration.py
@@ -17,7 +17,6 @@
 #
 
 import unittest
-
 import six
 
 from tests.compat import mock
@@ -94,6 +93,9 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
         self.kube_config.dags_folder = None
         self.kube_config.git_dags_folder_mount_point = None
         self.kube_config.kube_labels = {'dag_id': 'original_dag_id', 'my_label': 'label_id'}
+        self.kube_config.pod_template_file = ''
+        self.kube_config.restart_policy = ''
+        self.kube_config.image_pull_policy = ''
         self.api_client = ApiClient()
 
     @conf_vars({
@@ -358,7 +360,6 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
             worker_config.as_pod(),
             "default",
             "sample-uuid",
-
         )
         expected_labels = {
             'airflow-worker': 'sample-uuid',
@@ -672,6 +673,15 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
             k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name='secretref_b'))
         ], configmaps)
 
+    def test_pod_template_file(self):
+        fixture = 'tests/kubernetes/pod.yaml'
+        self.kube_config.pod_template_file = fixture
+        worker_config = WorkerConfiguration(self.kube_config)
+        result = worker_config.as_pod()
+        expected = PodGenerator.deserialize_model_file(fixture)
+        expected.metadata.name = mock.ANY
+        self.assertEqual(expected, result)
+
     def test_get_labels(self):
         worker_config = WorkerConfiguration(self.kube_config)
         labels = worker_config._get_labels({'my_kube_executor_label': 'kubernetes'}, {