You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2023/06/27 20:55:59 UTC

[airflow] branch main updated: fix spark-kubernetes-operator compatibility (#31798)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6693bdd72d fix spark-kubernetes-operator compatibility (#31798)
6693bdd72d is described below

commit 6693bdd72d70989f4400b5807e2945d814a83b85
Author: Hossein Torabi <ho...@booking.com>
AuthorDate: Tue Jun 27 22:55:51 2023 +0200

    fix spark-kubernetes-operator compatibility (#31798)
---
 .../cncf/kubernetes/operators/spark_kubernetes.py  | 89 +++++++++++++++-------
 .../kubernetes/operators/test_spark_kubernetes.py  | 84 +++++++++++++++++---
 2 files changed, 136 insertions(+), 37 deletions(-)

diff --git a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
index 2206e04a73..ccc13469fa 100644
--- a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
@@ -17,10 +17,12 @@
 # under the License.
 from __future__ import annotations
 
+import datetime
 from typing import TYPE_CHECKING, Sequence
 
 from kubernetes.watch import Watch
 
+from airflow import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook, _load_body_to_dict
 
@@ -43,6 +45,7 @@ class SparkKubernetesOperator(BaseOperator):
         for the to Kubernetes cluster.
     :param api_group: kubernetes api group of sparkApplication
     :param api_version: kubernetes api version of sparkApplication
+    :param watch: whether to watch the job status and logs or not
     """
 
     template_fields: Sequence[str] = ("application_file", "namespace")
@@ -60,6 +63,7 @@ class SparkKubernetesOperator(BaseOperator):
         in_cluster: bool | None = None,
         cluster_context: str | None = None,
         config_file: str | None = None,
+        watch: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -72,6 +76,7 @@ class SparkKubernetesOperator(BaseOperator):
         self.in_cluster = in_cluster
         self.cluster_context = cluster_context
         self.config_file = config_file
+        self.watch = watch
 
         self.hook = KubernetesHook(
             conn_id=self.kubernetes_conn_id,
@@ -84,35 +89,67 @@ class SparkKubernetesOperator(BaseOperator):
         body = _load_body_to_dict(self.application_file)
         name = body["metadata"]["name"]
         namespace = self.namespace or self.hook.get_namespace()
-        namespace_event_stream = Watch().stream(
-            self.hook.core_v1_client.list_namespaced_pod,
-            namespace=namespace,
-            _preload_content=False,
-            watch=True,
-            label_selector=f"sparkoperator.k8s.io/app-name={name},spark-role=driver",
-            field_selector="status.phase=Running",
-        )
 
-        self.hook.create_custom_object(
-            group=self.api_group,
-            version=self.api_version,
-            plural=self.plural,
-            body=body,
-            namespace=namespace,
-        )
-        for event in namespace_event_stream:
-            if event["type"] == "ADDED":
-                pod_log_stream = Watch().stream(
-                    self.hook.core_v1_client.read_namespaced_pod_log,
-                    name=f"{name}-driver",
+        response = None
+        is_job_created = False
+        if self.watch:
+            try:
+                namespace_event_stream = Watch().stream(
+                    self.hook.core_v1_client.list_namespaced_event,
+                    namespace=namespace,
+                    watch=True,
+                    field_selector=f"involvedObject.kind=SparkApplication,involvedObject.name={name}",
+                )
+
+                response = self.hook.create_custom_object(
+                    group=self.api_group,
+                    version=self.api_version,
+                    plural=self.plural,
+                    body=body,
                     namespace=namespace,
-                    _preload_content=False,
-                    timestamps=True,
                 )
-                for line in pod_log_stream:
-                    self.log.info(line)
-            else:
-                break
+                is_job_created = True
+                for event in namespace_event_stream:
+                    obj = event["object"]
+                    if event["object"].last_timestamp >= datetime.datetime.strptime(
+                        response["metadata"]["creationTimestamp"], "%Y-%m-%dT%H:%M:%S%z"
+                    ):
+                        self.log.info(obj.message)
+                        if obj.reason == "SparkDriverRunning":
+                            pod_log_stream = Watch().stream(
+                                self.hook.core_v1_client.read_namespaced_pod_log,
+                                name=f"{name}-driver",
+                                namespace=namespace,
+                                timestamps=True,
+                            )
+                            for line in pod_log_stream:
+                                self.log.info(line)
+                        elif obj.reason in [
+                            "SparkApplicationSubmissionFailed",
+                            "SparkApplicationFailed",
+                            "SparkApplicationDeleted",
+                        ]:
+                            is_job_created = False
+                            raise AirflowException(obj.message)
+                        elif obj.reason == "SparkApplicationCompleted":
+                            break
+                        else:
+                            continue
+            except Exception:
+                if is_job_created:
+                    self.on_kill()
+                raise
+
+        else:
+            response = self.hook.create_custom_object(
+                group=self.api_group,
+                version=self.api_version,
+                plural=self.plural,
+                body=body,
+                namespace=namespace,
+            )
+
+        return response
 
     def on_kill(self) -> None:
         body = _load_body_to_dict(self.application_file)
diff --git a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
index 7ea67aec80..c7f30dad24 100644
--- a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
@@ -16,8 +16,13 @@
 # under the License.
 from __future__ import annotations
 
-from unittest.mock import patch
+from datetime import datetime
+from unittest.mock import MagicMock, patch
 
+import pytest
+from dateutil import tz
+
+from airflow import AirflowException
 from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator
 
 
@@ -43,13 +48,22 @@ def test_spark_kubernetes_operator(mock_kubernetes_hook):
 @patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.Watch.stream")
 @patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes._load_body_to_dict")
 @patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
-def test_execute(mock_kubernetes_hook, mock_load_body_to_dict, mock_stream):
+def test_execute_with_watch(mock_kubernetes_hook, mock_load_body_to_dict, mock_stream):
     mock_load_body_to_dict.return_value = {"metadata": {"name": "spark-app"}}
+
+    mock_kubernetes_hook.return_value.create_custom_object.return_value = {
+        "metadata": {"name": "spark-app", "creationTimestamp": "2022-01-01T00:00:00Z"}
+    }
     mock_kubernetes_hook.return_value.get_namespace.return_value = "default"
-    mock_stream.side_effect = [[{"type": "ADDED"}], []]
 
-    op = SparkKubernetesOperator(task_id="task_id", application_file="application_file")
-    op.execute({})
+    object_mock = MagicMock()
+    object_mock.reason = "SparkDriverRunning"
+    object_mock.last_timestamp = datetime(2022, 1, 1, 23, 59, 59, tzinfo=tz.tzutc())
+    mock_stream.side_effect = [[{"object": object_mock}], []]
+
+    op = SparkKubernetesOperator(task_id="task_id", application_file="application_file", watch=True)
+    operator_output = op.execute({})
+
     mock_kubernetes_hook.return_value.create_custom_object.assert_called_once_with(
         group="sparkoperator.k8s.io",
         version="v1beta2",
@@ -60,22 +74,70 @@ def test_execute(mock_kubernetes_hook, mock_load_body_to_dict, mock_stream):
 
     assert mock_stream.call_count == 2
     mock_stream.assert_any_call(
-        mock_kubernetes_hook.return_value.core_v1_client.list_namespaced_pod,
+        mock_kubernetes_hook.return_value.core_v1_client.list_namespaced_event,
         namespace="default",
-        _preload_content=False,
         watch=True,
-        label_selector="sparkoperator.k8s.io/app-name=spark-app,spark-role=driver",
-        field_selector="status.phase=Running",
+        field_selector="involvedObject.kind=SparkApplication,involvedObject.name=spark-app",
     )
-
     mock_stream.assert_any_call(
         mock_kubernetes_hook.return_value.core_v1_client.read_namespaced_pod_log,
         name="spark-app-driver",
         namespace="default",
-        _preload_content=False,
         timestamps=True,
     )
 
+    assert operator_output == {"metadata": {"name": "spark-app", "creationTimestamp": "2022-01-01T00:00:00Z"}}
+
+
+@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.SparkKubernetesOperator.on_kill")
+@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.Watch.stream")
+@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes._load_body_to_dict")
+@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
+def test_raise_exception_when_job_fails(
+    mock_kubernetes_hook, mock_load_body_to_dict, mock_stream, mock_on_kill
+):
+    mock_load_body_to_dict.return_value = {"metadata": {"name": "spark-app"}}
+
+    mock_kubernetes_hook.return_value.create_custom_object.return_value = {
+        "metadata": {"name": "spark-app", "creationTimestamp": "2022-01-01T00:00:00Z"}
+    }
+    mock_kubernetes_hook.return_value.get_namespace.return_value = "default"
+
+    object_mock = MagicMock()
+    object_mock.reason = "SparkApplicationFailed"
+    object_mock.message = "spark-app submission failed"
+    object_mock.last_timestamp = datetime(2022, 1, 1, 23, 59, 59, tzinfo=tz.tzutc())
+
+    mock_stream.side_effect = [[{"object": object_mock}], []]
+    op = SparkKubernetesOperator(task_id="task_id", application_file="application_file", watch=True)
+    with pytest.raises(AirflowException, match="spark-app submission failed"):
+        op.execute({})
+
+    assert mock_on_kill.has_called_once()
+
+
+@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes._load_body_to_dict")
+@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
+def test_execute_without_watch(mock_kubernetes_hook, mock_load_body_to_dict):
+    mock_load_body_to_dict.return_value = {"metadata": {"name": "spark-app"}}
+
+    mock_kubernetes_hook.return_value.create_custom_object.return_value = {
+        "metadata": {"name": "spark-app", "creationTimestamp": "2022-01-01T00:00:00Z"}
+    }
+    mock_kubernetes_hook.return_value.get_namespace.return_value = "default"
+
+    op = SparkKubernetesOperator(task_id="task_id", application_file="application_file")
+    operator_output = op.execute({})
+
+    mock_kubernetes_hook.return_value.create_custom_object.assert_called_once_with(
+        group="sparkoperator.k8s.io",
+        version="v1beta2",
+        plural="sparkapplications",
+        body={"metadata": {"name": "spark-app"}},
+        namespace="default",
+    )
+    assert operator_output == {"metadata": {"name": "spark-app", "creationTimestamp": "2022-01-01T00:00:00Z"}}
+
 
 @patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes._load_body_to_dict")
 @patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")