Posted to commits@airflow.apache.org by di...@apache.org on 2021/04/30 01:42:09 UTC

[airflow] branch master updated: Fix parallelism after KubeExecutor pod adoption (#15555)

This is an automated email from the ASF dual-hosted git repository.

dimberman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 9c296c9  Fix parallelism after KubeExecutor pod adoption (#15555)
9c296c9 is described below

commit 9c296c93967692133470626adb007941754b2f96
Author: Jed Cunningham <66...@users.noreply.github.com>
AuthorDate: Thu Apr 29 19:41:51 2021 -0600

    Fix parallelism after KubeExecutor pod adoption (#15555)
    
    * Fix parallelism after KubeExecutor pod adoption
    
    This fixes a bug where adopted pods were not counted as running by the
    executor, so they were ignored both when enforcing parallelism and in
    the exported metrics. Adopted tasks are now tracked just like tasks the
    executor started itself.
    
    * Fix circular import
---
 airflow/executors/kubernetes_executor.py          | 66 ++++++++---------------
 airflow/kubernetes/kubernetes_helper_functions.py | 20 +++++++
 airflow/models/taskinstance.py                    |  3 +-
 tests/executors/test_kubernetes_executor.py       | 55 +++++++++++++------
 4 files changed, 84 insertions(+), 60 deletions(-)
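
For context on the parallelism half of the fix: the executor derives its
open slots (and its running-task metrics) from the set of running task
keys, so a pod never recorded in `self.running` simply does not count
against the limit. A minimal sketch of that accounting (illustrative
names, not the actual BaseExecutor code):

    from typing import Set, Tuple

    TaskKey = Tuple[str, str, str, int]  # (dag_id, task_id, execution_date, try_number)

    class ExecutorSketch:
        """Illustrative stand-in for the executor's slot accounting."""

        def __init__(self, parallelism: int = 32):
            self.parallelism = parallelism
            self.running: Set[TaskKey] = set()

        def open_slots(self) -> int:
            # Before this fix, adopted pods never entered `running`, so the
            # executor believed it had more free slots than it really did.
            return self.parallelism - len(self.running)

    executor = ExecutorSketch(parallelism=2)
    executor.running.add(("dag", "task", "2021-04-29T00:00:00", 1))  # adopted pod
    assert executor.open_slots() == 1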

diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py
index a2ae701..5c42acd 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -30,11 +30,8 @@ from datetime import timedelta
 from queue import Empty, Queue  # pylint: disable=unused-import
 from typing import Any, Dict, List, Optional, Tuple
 
-import kubernetes
-from dateutil import parser
 from kubernetes import client, watch
 from kubernetes.client import Configuration, models as k8s
-from kubernetes.client.models import V1Pod
 from kubernetes.client.rest import ApiException
 from urllib3.exceptions import ReadTimeoutError
 
@@ -43,10 +40,9 @@ from airflow.executors.base_executor import NOT_STARTED_MESSAGE, BaseExecutor, C
 from airflow.kubernetes import pod_generator
 from airflow.kubernetes.kube_client import get_kube_client
 from airflow.kubernetes.kube_config import KubeConfig
-from airflow.kubernetes.kubernetes_helper_functions import create_pod_id
+from airflow.kubernetes.kubernetes_helper_functions import annotations_to_key, create_pod_id
 from airflow.kubernetes.pod_generator import PodGenerator
-from airflow.models import TaskInstance
-from airflow.models.taskinstance import TaskInstanceKey
+from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
 from airflow.settings import pod_mutation_hook
 from airflow.utils import timezone
 from airflow.utils.event_scheduler import EventScheduler
@@ -246,7 +242,7 @@ class AirflowKubernetesScheduler(LoggingMixin):
         self.scheduler_job_id = scheduler_job_id
         self.kube_watcher = self._make_kube_watcher()
 
-    def run_pod_async(self, pod: V1Pod, **kwargs):
+    def run_pod_async(self, pod: k8s.V1Pod, **kwargs):
         """Runs POD asynchronously"""
         pod_mutation_hook(pod)
 
@@ -372,20 +368,11 @@ class AirflowKubernetesScheduler(LoggingMixin):
         self.log.info(
             'Attempting to finish pod; pod_id: %s; state: %s; annotations: %s', pod_id, state, annotations
         )
-        key = self._annotations_to_key(annotations=annotations)
+        key = annotations_to_key(annotations=annotations)
         if key:
             self.log.debug('finishing job %s - %s (%s)', key, state, pod_id)
             self.result_queue.put((key, state, pod_id, namespace, resource_version))
 
-    def _annotations_to_key(self, annotations: Dict[str, str]) -> Optional[TaskInstanceKey]:
-        self.log.debug("Creating task key for annotations %s", annotations)
-        dag_id = annotations['dag_id']
-        task_id = annotations['task_id']
-        try_number = int(annotations['try_number'])
-        execution_date = parser.parse(annotations['execution_date'])
-
-        return TaskInstanceKey(dag_id, task_id, execution_date, try_number)
-
     def _flush_watcher_queue(self) -> None:
         self.log.debug('Executor shutting down, watcher_queue approx. size=%d', self.watcher_queue.qsize())
         while True:
@@ -653,14 +640,7 @@ class KubernetesExecutor(BaseExecutor, LoggingMixin):
     def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
         tis_to_flush = [ti for ti in tis if not ti.queued_by_job_id]
         scheduler_job_ids = {ti.queued_by_job_id for ti in tis}
-        pod_ids = {
-            create_pod_id(
-                dag_id=pod_generator.make_safe_label_value(ti.dag_id),
-                task_id=pod_generator.make_safe_label_value(ti.task_id),
-            ): ti
-            for ti in tis
-            if ti.queued_by_job_id
-        }
+        pod_ids = {ti.key: ti for ti in tis if ti.queued_by_job_id}
         kube_client: client.CoreV1Api = self.kube_client
         for scheduler_job_id in scheduler_job_ids:
             scheduler_job_id = pod_generator.make_safe_label_value(str(scheduler_job_id))
@@ -672,7 +652,9 @@ class KubernetesExecutor(BaseExecutor, LoggingMixin):
         tis_to_flush.extend(pod_ids.values())
         return tis_to_flush
 
-    def adopt_launched_task(self, kube_client, pod, pod_ids: dict):
+    def adopt_launched_task(
+        self, kube_client: client.CoreV1Api, pod: k8s.V1Pod, pod_ids: Dict[TaskInstanceKey, k8s.V1Pod]
+    ) -> None:
         """
         Patch existing pod so that the current KubernetesJobWatcher can monitor it via label selectors
 
@@ -684,27 +666,23 @@ class KubernetesExecutor(BaseExecutor, LoggingMixin):
         pod.metadata.labels['airflow-worker'] = pod_generator.make_safe_label_value(
             str(self.scheduler_job_id)
         )
-        dag_id = pod.metadata.labels['dag_id']
-        task_id = pod.metadata.labels['task_id']
-        pod_id = create_pod_id(dag_id=dag_id, task_id=task_id)
+        pod_id = annotations_to_key(pod.metadata.annotations)
         if pod_id not in pod_ids:
-            self.log.error(
-                "attempting to adopt task %s in dag %s which was not specified by database",
-                task_id,
-                dag_id,
+            self.log.error("attempting to adopt taskinstance which was not specified by database: %s", pod_id)
+            return
+
+        try:
+            kube_client.patch_namespaced_pod(
+                name=pod.metadata.name,
+                namespace=pod.metadata.namespace,
+                body=PodGenerator.serialize_pod(pod),
             )
-        else:
-            try:
-                kube_client.patch_namespaced_pod(
-                    name=pod.metadata.name,
-                    namespace=pod.metadata.namespace,
-                    body=PodGenerator.serialize_pod(pod),
-                )
-                pod_ids.pop(pod_id)
-            except ApiException as e:
-                self.log.info("Failed to adopt pod %s. Reason: %s", pod.metadata.name, e)
+            pod_ids.pop(pod_id)
+            self.running.add(pod_id)
+        except ApiException as e:
+            self.log.info("Failed to adopt pod %s. Reason: %s", pod.metadata.name, e)
 
-    def _adopt_completed_pods(self, kube_client: kubernetes.client.CoreV1Api):
+    def _adopt_completed_pods(self, kube_client: client.CoreV1Api) -> None:
         """
 
         Patch completed pod so that the KubernetesJobWatcher can delete it.
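
Worth noting why the executor diff above switches from a label-derived
pod id to a TaskInstanceKey built from pod annotations: Kubernetes label
values are length-limited and restricted in alphabet, so Airflow must
sanitize them and they become lossy, whereas annotations can carry the
exact dag_id/task_id values. A rough sketch of why label-derived keys
are ambiguous (illustrative sanitizer, not Airflow's actual
make_safe_label_value, which also appends a hash when truncating):

    import re

    MAX_LABEL_LEN = 63  # Kubernetes cap on label values

    def naive_safe_label_value(value: str) -> str:
        # Replace characters labels disallow, then truncate to the cap.
        return re.sub(r'[^a-zA-Z0-9_.-]', '-', value)[:MAX_LABEL_LEN]

    # Two distinct task ids collapse to the same label value once truncated,
    # which is why the lossless annotations make the safer lookup key.
    a = naive_safe_label_value('long_task_' + 'x' * 60 + '_variant_a')
    b = naive_safe_label_value('long_task_' + 'x' * 60 + '_variant_b')
    assert a == b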
diff --git a/airflow/kubernetes/kubernetes_helper_functions.py b/airflow/kubernetes/kubernetes_helper_functions.py
index a8bbd8a..da58a38 100644
--- a/airflow/kubernetes/kubernetes_helper_functions.py
+++ b/airflow/kubernetes/kubernetes_helper_functions.py
@@ -15,6 +15,15 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import logging
+from typing import Dict, Optional
+
+from dateutil import parser
+
+from airflow.models.taskinstance import TaskInstanceKey
+
+log = logging.getLogger(__name__)
+
 
 def _strip_unsafe_kubernetes_special_chars(string: str) -> str:
     """
@@ -44,3 +53,14 @@ def create_pod_id(dag_id: str, task_id: str) -> str:
     safe_dag_id = _strip_unsafe_kubernetes_special_chars(dag_id)
     safe_task_id = _strip_unsafe_kubernetes_special_chars(task_id)
     return safe_dag_id + safe_task_id
+
+
+def annotations_to_key(annotations: Dict[str, str]) -> Optional[TaskInstanceKey]:
+    """Build a TaskInstanceKey based on pod annotations"""
+    log.debug("Creating task key for annotations %s", annotations)
+    dag_id = annotations['dag_id']
+    task_id = annotations['task_id']
+    try_number = int(annotations['try_number'])
+    execution_date = parser.parse(annotations['execution_date'])
+
+    return TaskInstanceKey(dag_id, task_id, execution_date, try_number)
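
A quick usage sketch for the new helper (assumes an Airflow
installation; the annotation values mirror what the pod generator writes
onto worker pods):

    from airflow.kubernetes.kubernetes_helper_functions import annotations_to_key

    key = annotations_to_key({
        'dag_id': 'example_dag',
        'task_id': 'example_task',
        'execution_date': '2021-04-29T19:41:51+00:00',
        'try_number': '1',
    })
    # key is a TaskInstanceKey whose fields round-trip the annotations
    print(key.dag_id, key.task_id, key.try_number)  # example_dag example_task 1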
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index c5a555b..d51a5af 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -77,7 +77,6 @@ try:
     from kubernetes.client.api_client import ApiClient
 
     from airflow.kubernetes.kube_config import KubeConfig
-    from airflow.kubernetes.kubernetes_helper_functions import create_pod_id
     from airflow.kubernetes.pod_generator import PodGenerator
 except ImportError:
     ApiClient = None
@@ -1776,6 +1775,8 @@ class TaskInstance(Base, LoggingMixin):  # pylint: disable=R0902,R0904
 
     def render_k8s_pod_yaml(self) -> Optional[dict]:
         """Render k8s pod yaml"""
+        from airflow.kubernetes.kubernetes_helper_functions import create_pod_id  # Circular import
+
         kube_config = KubeConfig()
         pod = PodGenerator.construct_pod(
             dag_id=self.dag_id,
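
The local import above is the standard fix for a circular dependency:
kubernetes_helper_functions now imports TaskInstanceKey from
taskinstance, so taskinstance can no longer import create_pod_id at
module load time. A schematic of the cycle and the deferred-import fix
(module names illustrative; a real cycle needs two files, so this is
shown as comments rather than runnable code):

    # helpers.py
    #     from models import TaskKey           # helpers -> models
    #
    # models.py
    #     from helpers import create_pod_id    # models -> helpers: cycle at import time
    #
    # Fix in models.py: defer the import to call time, after both modules
    # have finished initializing.
    #     def render_pod_yaml(self):
    #         from helpers import create_pod_id
    #         ...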
diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py
index bd6b2cc..e449b96 100644
--- a/tests/executors/test_kubernetes_executor.py
+++ b/tests/executors/test_kubernetes_executor.py
@@ -41,6 +41,7 @@ try:
         get_base_pod_from_template,
     )
     from airflow.kubernetes import pod_generator
+    from airflow.kubernetes.kubernetes_helper_functions import annotations_to_key
     from airflow.kubernetes.pod_generator import PodGenerator, datetime_to_label_safe_datestring
     from airflow.utils.state import State
 except ImportError:
@@ -380,22 +381,28 @@ class TestKubernetesExecutor(unittest.TestCase):
     def test_try_adopt_task_instances(self, mock_adopt_completed_pods, mock_adopt_launched_task):
         executor = self.kubernetes_executor
         executor.scheduler_job_id = "10"
-        mock_ti = mock.MagicMock(queued_by_job_id="1", external_executor_id="1", dag_id="dag", task_id="task")
-        pod = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="foo", labels={"dag_id": "dag", "task_id": "task"}))
-        pod_id = create_pod_id(dag_id="dag", task_id="task")
+        ti_key = annotations_to_key(
+            {
+                'dag_id': 'dag',
+                'execution_date': datetime.utcnow().isoformat(),
+                'task_id': 'task',
+                'try_number': '1',
+            }
+        )
+        mock_ti = mock.MagicMock(queued_by_job_id="1", external_executor_id="1", key=ti_key)
+        pod = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="foo"))
         mock_kube_client = mock.MagicMock()
         mock_kube_client.list_namespaced_pod.return_value.items = [pod]
         executor.kube_client = mock_kube_client
 
         # First adoption
-        executor.try_adopt_task_instances([mock_ti])
+        reset_tis = executor.try_adopt_task_instances([mock_ti])
         mock_kube_client.list_namespaced_pod.assert_called_once_with(
             namespace='default', label_selector='airflow-worker=1'
         )
-        mock_adopt_launched_task.assert_called_once_with(mock_kube_client, pod, {pod_id: mock_ti})
+        mock_adopt_launched_task.assert_called_once_with(mock_kube_client, pod, {ti_key: mock_ti})
         mock_adopt_completed_pods.assert_called_once()
-        # We aren't checking the return value of `try_adopt_task_instances` because it relies on
-        # `adopt_launched_task` mutating its arg. This should be refactored, but not right now.
+        assert reset_tis == [mock_ti]  # assume failure adopting when checking return
 
         # Second adoption (queued_by_job_id and external_executor_id no longer match)
         mock_kube_client.reset_mock()
@@ -404,13 +411,16 @@ class TestKubernetesExecutor(unittest.TestCase):
 
         mock_ti.queued_by_job_id = "10"  # scheduler_job would have updated this after the first adoption
         executor.scheduler_job_id = "20"
+        # assume success adopting when checking return, `adopt_launched_task` pops `ti_key` from `pod_ids`
+        mock_adopt_launched_task.side_effect = lambda client, pod, pod_ids: pod_ids.pop(ti_key)
 
-        executor.try_adopt_task_instances([mock_ti])
+        reset_tis = executor.try_adopt_task_instances([mock_ti])
         mock_kube_client.list_namespaced_pod.assert_called_once_with(
             namespace='default', label_selector='airflow-worker=10'
         )
-        mock_adopt_launched_task.assert_called_once_with(mock_kube_client, pod, {pod_id: mock_ti})
+        mock_adopt_launched_task.assert_called_once()  # Won't check args this time around as they get mutated
         mock_adopt_completed_pods.assert_called_once()
+        assert reset_tis == []  # This time our return is empty - no TIs to reset
 
     @mock.patch('airflow.executors.kubernetes_executor.KubernetesExecutor._adopt_completed_pods')
     def test_try_adopt_task_instances_multiple_scheduler_ids(self, mock_adopt_completed_pods):
@@ -455,17 +465,24 @@ class TestKubernetesExecutor(unittest.TestCase):
     def test_adopt_launched_task(self, mock_kube_client):
         executor = self.kubernetes_executor
         executor.scheduler_job_id = "modified"
-        pod_ids = {"dagtask": {}}
+        annotations = {
+            'dag_id': 'dag',
+            'execution_date': datetime.utcnow().isoformat(),
+            'task_id': 'task',
+            'try_number': '1',
+        }
+        ti_key = annotations_to_key(annotations)
         pod = k8s.V1Pod(
-            metadata=k8s.V1ObjectMeta(
-                name="foo", labels={"airflow-worker": "bar", "dag_id": "dag", "task_id": "task"}
-            )
+            metadata=k8s.V1ObjectMeta(name="foo", labels={"airflow-worker": "bar"}, annotations=annotations)
         )
+        pod_ids = {ti_key: {}}
+
         executor.adopt_launched_task(mock_kube_client, pod=pod, pod_ids=pod_ids)
         assert mock_kube_client.patch_namespaced_pod.call_args[1] == {
             'body': {
                 'metadata': {
-                    'labels': {'airflow-worker': 'modified', 'dag_id': 'dag', 'task_id': 'task'},
+                    'labels': {'airflow-worker': 'modified'},
+                    'annotations': annotations,
                     'name': 'foo',
                 }
             },
@@ -473,6 +490,7 @@ class TestKubernetesExecutor(unittest.TestCase):
             'namespace': None,
         }
         assert pod_ids == {}
+        assert executor.running == {ti_key}
 
     @mock.patch('airflow.executors.kubernetes_executor.get_kube_client')
     def test_not_adopt_unassigned_task(self, mock_kube_client):
@@ -486,7 +504,14 @@ class TestKubernetesExecutor(unittest.TestCase):
         pod_ids = {"foobar": {}}
         pod = k8s.V1Pod(
             metadata=k8s.V1ObjectMeta(
-                name="foo", labels={"airflow-worker": "bar", "dag_id": "dag", "task_id": "task"}
+                name="foo",
+                labels={"airflow-worker": "bar"},
+                annotations={
+                    'dag_id': 'dag',
+                    'execution_date': datetime.utcnow().isoformat(),
+                    'task_id': 'task',
+                    'try_number': '1',
+                },
             )
         )
         executor.adopt_launched_task(mock_kube_client, pod=pod, pod_ids=pod_ids)
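
One test detail worth calling out: the second-adoption assertion above
relies on a side_effect that mutates the mock's argument, mimicking how
the real adopt_launched_task pops a successfully adopted key from
pod_ids. A standalone sketch of that mocking pattern (names here are
illustrative):

    from unittest import mock

    pod_ids = {'key': 'ti'}
    adopt = mock.MagicMock()
    # Make the mock mutate its argument, as the real method would on success.
    adopt.side_effect = lambda client, pod, ids: ids.pop('key')
    adopt('client', 'pod', pod_ids)
    assert pod_ids == {}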