You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by di...@apache.org on 2020/11/05 22:53:00 UTC

[airflow] branch master updated: Add ability to specify pod_template_file in executor_config (#11784)

This is an automated email from the ASF dual-hosted git repository.

dimberman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 68ba54b  Add ability to specify pod_template_file in executor_config (#11784)
68ba54b is described below

commit 68ba54bbd5a275fba1a126f8e67bd69e5cf4b362
Author: Daniel Imberman <da...@gmail.com>
AuthorDate: Thu Nov 5 14:48:05 2020 -0800

    Add ability to specify pod_template_file in executor_config (#11784)
    
    * Add pod_template_override to executor_config
    
    Users will be able to override the base pod_template_file on a per-task
    basis.
    
    * change docstring
    
    * fix doc
    
    * fix static checks
    
    * add description
---
 .../example_kubernetes_executor_config.py          | 12 ++++++++
 airflow/executors/kubernetes_executor.py           | 34 +++++++++++++++++-----
 chart/requirements.lock                            |  6 ++--
 docs/executor/kubernetes.rst                       | 10 ++++++-
 tests/executors/test_kubernetes_executor.py        | 14 +++++++++
 tests/www/test_views.py                            |  5 +++-
 6 files changed, 69 insertions(+), 12 deletions(-)

diff --git a/airflow/example_dags/example_kubernetes_executor_config.py b/airflow/example_dags/example_kubernetes_executor_config.py
index 57b2c4a..5fbc575 100644
--- a/airflow/example_dags/example_kubernetes_executor_config.py
+++ b/airflow/example_dags/example_kubernetes_executor_config.py
@@ -99,6 +99,17 @@ with DAG(
     )
     # [END task_with_volume]
 
+    # [START task_with_template]
+    task_with_template = PythonOperator(
+        task_id="task_with_template",
+        python_callable=print_stuff,
+        executor_config={
+            "pod_template_file": "/usr/local/airflow/pod_templates/basic_template.yaml",
+            "pod_override": k8s.V1Pod(metadata=k8s.V1ObjectMeta(labels={"release": "stable"})),
+        },
+    )
+    # [END task_with_template]
+
     # [START task_with_sidecar]
     sidecar_task = PythonOperator(
         task_id="task_with_sidecar",
@@ -146,3 +157,4 @@ with DAG(
     start_task >> volume_task >> third_task
     start_task >> other_ns_task
     start_task >> sidecar_task
+    start_task >> task_with_template
diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py
index ad200e1..b7071a0 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -32,7 +32,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import kubernetes
 from dateutil import parser
 from kubernetes import client, watch
-from kubernetes.client import Configuration
+from kubernetes.client import Configuration, models as k8s
 from kubernetes.client.rest import ApiException
 from urllib3.exceptions import ReadTimeoutError
 
@@ -50,8 +50,8 @@ from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import provide_session
 from airflow.utils.state import State
 
-# TaskInstance key, command, configuration
-KubernetesJobType = Tuple[TaskInstanceKey, CommandType, Any]
+# TaskInstance key, command, configuration, pod_template_file
+KubernetesJobType = Tuple[TaskInstanceKey, CommandType, Any, Optional[str]]
 
 # key, state, pod_id, namespace, resource_version
 KubernetesResultsType = Tuple[TaskInstanceKey, Optional[str], str, str, str]
@@ -341,13 +341,14 @@ class AirflowKubernetesScheduler(LoggingMixin):
         status
         """
         self.log.info('Kubernetes job is %s', str(next_job))
-        key, command, kube_executor_config = next_job
+        key, command, kube_executor_config, pod_template_file = next_job
         dag_id, task_id, execution_date, try_number = key
 
         if command[0:3] != ["airflow", "tasks", "run"]:
             raise ValueError('The command must start with ["airflow", "tasks", "run"].')
 
-        base_worker_pod = PodGenerator.deserialize_model_file(self.kube_config.pod_template_file)
+        base_worker_pod = get_base_pod_from_template(pod_template_file, self.kube_config)
+
         if not base_worker_pod:
             raise AirflowException(
                 f"could not find a valid worker template yaml at {self.kube_config.pod_template_file}"
@@ -505,6 +506,21 @@ def create_pod_id(dag_id: str, task_id: str) -> str:
     return safe_dag_id + safe_task_id
 
 
+def get_base_pod_from_template(pod_template_file: Optional[str], kube_config: Any) -> k8s.V1Pod:
+    """
+    Reads either the pod_template_file set in the executor_config or the base pod_template_file
+    set in the airflow.cfg to craft a "base pod" that will be used by the KubernetesExecutor
+
+    :param pod_template_file: absolute path to a pod_template_file.yaml or None
+    :param kube_config: The KubeConfig class generated by airflow that contains all kube metadata
+    :return: a V1Pod that can be used as the base pod for k8s tasks
+    """
+    if pod_template_file:
+        return PodGenerator.deserialize_model_file(pod_template_file)
+    else:
+        return PodGenerator.deserialize_model_file(kube_config.pod_template_file)
+
+
 class KubernetesExecutor(BaseExecutor, LoggingMixin):
     """Executor for Kubernetes"""
 
@@ -619,10 +635,14 @@ class KubernetesExecutor(BaseExecutor, LoggingMixin):
         """Executes task asynchronously"""
         self.log.info('Add task %s with command %s with executor_config %s', key, command, executor_config)
         kube_executor_config = PodGenerator.from_obj(executor_config)
+        if executor_config:
+            pod_template_file = executor_config.get("pod_template_override", None)
+        else:
+            pod_template_file = None
         if not self.task_queue:
             raise AirflowException(NOT_STARTED_MESSAGE)
         self.event_buffer[key] = (State.QUEUED, self.scheduler_job_id)
-        self.task_queue.put((key, command, kube_executor_config))
+        self.task_queue.put((key, command, kube_executor_config, pod_template_file))
 
     def sync(self) -> None:
         """Synchronize task state."""
@@ -677,7 +697,7 @@ class KubernetesExecutor(BaseExecutor, LoggingMixin):
                 except ApiException as e:
                     if e.reason == "BadRequest":
                         self.log.error("Request was invalid. Failing task")
-                        key, _, _ = task
+                        key, _, _, _ = task
                         self.change_state(key, State.FAILED, e)
                     else:
                         self.log.warning(
diff --git a/chart/requirements.lock b/chart/requirements.lock
index 715458e..eb62c80 100644
--- a/chart/requirements.lock
+++ b/chart/requirements.lock
@@ -1,6 +1,6 @@
 dependencies:
 - name: postgresql
-  repository: https://kubernetes-charts.storage.googleapis.com/
+  repository: https://kubernetes-charts.storage.googleapis.com
   version: 6.3.12
-digest: sha256:e8d53453861c590e6ae176331634c9268a11cf894be17ed580fa2b347101be97
-generated: "2020-10-27T21:16:13.0063538Z"
+digest: sha256:58d88cf56e78b2380091e9e16cc6ccf58b88b3abe4a1886dd47cd9faef5309af
+generated: "2020-11-04T15:59:36.967913-08:00"
diff --git a/docs/executor/kubernetes.rst b/docs/executor/kubernetes.rst
index f3940d8..c016726 100644
--- a/docs/executor/kubernetes.rst
+++ b/docs/executor/kubernetes.rst
@@ -123,7 +123,15 @@ name ``base`` and a second container containing your desired sidecar.
     :start-after: [START task_with_sidecar]
     :end-before: [END task_with_sidecar]
 
-In the following example, we create a sidecar container that shares a volume_mount for data sharing.
+You can also create a custom ``pod_template_file`` on a per-task basis so that you can recycle the same base values between multiple tasks.
+This will replace the default ``pod_template_file`` named in the airflow.cfg and then override that template using the ``pod_override`` key.
+
+Here is an example of a task with both features:
+
+.. exampleinclude:: /../airflow/example_dags/example_kubernetes_executor_config.py
+    :language: python
+    :start-after: [START task_with_template]
+    :end-before: [END task_with_template]
 
 KubernetesExecutor Architecture
 ################################
diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py
index 9765669..9d8d72f 100644
--- a/tests/executors/test_kubernetes_executor.py
+++ b/tests/executors/test_kubernetes_executor.py
@@ -35,6 +35,7 @@ try:
         AirflowKubernetesScheduler,
         KubernetesExecutor,
         create_pod_id,
+        get_base_pod_from_template,
     )
     from airflow.kubernetes import pod_generator
     from airflow.kubernetes.pod_generator import PodGenerator
@@ -84,6 +85,19 @@ class TestAirflowKubernetesScheduler(unittest.TestCase):
             pod_name = PodGenerator.make_unique_pod_id(create_pod_id(dag_id, task_id))
             self.assertTrue(self._is_valid_pod_id(pod_name))
 
+    @unittest.skipIf(AirflowKubernetesScheduler is None, 'kubernetes python package is not installed')
+    @mock.patch("airflow.kubernetes.pod_generator.PodGenerator")
+    @mock.patch("airflow.executors.kubernetes_executor.KubeConfig")
+    def test_get_base_pod_from_template(self, mock_kubeconfig, mock_generator):
+        pod_template_file_path = "/bar/biz"
+        get_base_pod_from_template(pod_template_file_path, None)
+        self.assertEqual("deserialize_model_dict", mock_generator.mock_calls[0][0])
+        self.assertEqual(pod_template_file_path, mock_generator.mock_calls[0][1][0])
+        mock_kubeconfig.pod_template_file = "/foo/bar"
+        get_base_pod_from_template(None, mock_kubeconfig)
+        self.assertEqual("deserialize_model_dict", mock_generator.mock_calls[1][0])
+        self.assertEqual("/foo/bar", mock_generator.mock_calls[1][1][0])
+
     def test_make_safe_label_value(self):
         for dag_id, task_id in self._cases():
             safe_dag_id = pod_generator.make_safe_label_value(dag_id)
diff --git a/tests/www/test_views.py b/tests/www/test_views.py
index e171805..901e081 100644
--- a/tests/www/test_views.py
+++ b/tests/www/test_views.py
@@ -1164,7 +1164,10 @@ class TestRedocView(TestBase):
 
         self.assertEqual(len(templates), 1)
         self.assertEqual(templates[0].name, 'airflow/redoc.html')
-        self.assertEqual(templates[0].local_context, {'openapi_spec_url': '/api/v1/openapi.yaml'})
+        self.assertEqual(
+            templates[0].local_context,
+            {'openapi_spec_url': '/api/v1/openapi.yaml'},
+        )
 
 
 class TestLogView(TestBase):