Posted to commits@airflow.apache.org by po...@apache.org on 2023/06/27 20:55:59 UTC
[airflow] branch main updated: fix spark-kubernetes-operator compatibility (#31798)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6693bdd72d fix spark-kubernetes-operator compatibility (#31798)
6693bdd72d is described below
commit 6693bdd72d70989f4400b5807e2945d814a83b85
Author: Hossein Torabi <ho...@booking.com>
AuthorDate: Tue Jun 27 22:55:51 2023 +0200
fix spark-kubernetes-operator compatibility (#31798)
---
.../cncf/kubernetes/operators/spark_kubernetes.py | 89 +++++++++++++++-------
.../kubernetes/operators/test_spark_kubernetes.py | 84 +++++++++++++++++---
2 files changed, 136 insertions(+), 37 deletions(-)
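
For reference, this change adds a watch flag to SparkKubernetesOperator: when enabled, the task follows the SparkApplication's Kubernetes events and streams the driver pod logs instead of only submitting the custom resource. A minimal usage sketch (the DAG id, schedule, and application manifest path below are illustrative placeholders, not part of this commit):

    import pendulum

    from airflow import DAG
    from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator

    with DAG(
        dag_id="spark_pi_example",                    # placeholder DAG id
        start_date=pendulum.datetime(2023, 1, 1, tz="UTC"),
        schedule=None,
        catchup=False,
    ):
        submit = SparkKubernetesOperator(
            task_id="submit_spark_pi",
            application_file="spark-pi.yaml",         # placeholder SparkApplication manifest
            namespace="default",
            kubernetes_conn_id="kubernetes_default",
            watch=True,                               # new: follow events and stream driver logs
        )
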
diff --git a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
index 2206e04a73..ccc13469fa 100644
--- a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
@@ -17,10 +17,12 @@
# under the License.
from __future__ import annotations
+import datetime
from typing import TYPE_CHECKING, Sequence
from kubernetes.watch import Watch
+from airflow import AirflowException
from airflow.models import BaseOperator
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook, _load_body_to_dict
@@ -43,6 +45,7 @@ class SparkKubernetesOperator(BaseOperator):
for the Kubernetes cluster.
:param api_group: kubernetes api group of sparkApplication
:param api_version: kubernetes api version of sparkApplication
+ :param watch: whether to watch the job status and logs or not
"""
template_fields: Sequence[str] = ("application_file", "namespace")
@@ -60,6 +63,7 @@ class SparkKubernetesOperator(BaseOperator):
in_cluster: bool | None = None,
cluster_context: str | None = None,
config_file: str | None = None,
+ watch: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -72,6 +76,7 @@ class SparkKubernetesOperator(BaseOperator):
self.in_cluster = in_cluster
self.cluster_context = cluster_context
self.config_file = config_file
+ self.watch = watch
self.hook = KubernetesHook(
conn_id=self.kubernetes_conn_id,
@@ -84,35 +89,67 @@ class SparkKubernetesOperator(BaseOperator):
body = _load_body_to_dict(self.application_file)
name = body["metadata"]["name"]
namespace = self.namespace or self.hook.get_namespace()
- namespace_event_stream = Watch().stream(
- self.hook.core_v1_client.list_namespaced_pod,
- namespace=namespace,
- _preload_content=False,
- watch=True,
- label_selector=f"sparkoperator.k8s.io/app-name={name},spark-role=driver",
- field_selector="status.phase=Running",
- )
- self.hook.create_custom_object(
- group=self.api_group,
- version=self.api_version,
- plural=self.plural,
- body=body,
- namespace=namespace,
- )
- for event in namespace_event_stream:
- if event["type"] == "ADDED":
- pod_log_stream = Watch().stream(
- self.hook.core_v1_client.read_namespaced_pod_log,
- name=f"{name}-driver",
+ response = None
+ is_job_created = False
+ if self.watch:
+ try:
+ namespace_event_stream = Watch().stream(
+ self.hook.core_v1_client.list_namespaced_event,
+ namespace=namespace,
+ watch=True,
+ field_selector=f"involvedObject.kind=SparkApplication,involvedObject.name={name}",
+ )
+
+ response = self.hook.create_custom_object(
+ group=self.api_group,
+ version=self.api_version,
+ plural=self.plural,
+ body=body,
namespace=namespace,
- _preload_content=False,
- timestamps=True,
)
- for line in pod_log_stream:
- self.log.info(line)
- else:
- break
+ is_job_created = True
+ for event in namespace_event_stream:
+ obj = event["object"]
+ if event["object"].last_timestamp >= datetime.datetime.strptime(
+ response["metadata"]["creationTimestamp"], "%Y-%m-%dT%H:%M:%S%z"
+ ):
+ self.log.info(obj.message)
+ if obj.reason == "SparkDriverRunning":
+ pod_log_stream = Watch().stream(
+ self.hook.core_v1_client.read_namespaced_pod_log,
+ name=f"{name}-driver",
+ namespace=namespace,
+ timestamps=True,
+ )
+ for line in pod_log_stream:
+ self.log.info(line)
+ elif obj.reason in [
+ "SparkApplicationSubmissionFailed",
+ "SparkApplicationFailed",
+ "SparkApplicationDeleted",
+ ]:
+ is_job_created = False
+ raise AirflowException(obj.message)
+ elif obj.reason == "SparkApplicationCompleted":
+ break
+ else:
+ continue
+ except Exception:
+ if is_job_created:
+ self.on_kill()
+ raise
+
+ else:
+ response = self.hook.create_custom_object(
+ group=self.api_group,
+ version=self.api_version,
+ plural=self.plural,
+ body=body,
+ namespace=namespace,
+ )
+
+ return response
def on_kill(self) -> None:
body = _load_body_to_dict(self.application_file)
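
Outside Airflow, the watch flow added above boils down to the following sketch against the kubernetes Python client (the application name and namespace are hypothetical placeholders); it only illustrates the event/log watch pattern, not the operator itself:

    from kubernetes import client, config, watch

    config.load_kube_config()                      # or config.load_incluster_config() inside a pod
    core_v1 = client.CoreV1Api()
    name, namespace = "spark-pi", "default"        # hypothetical SparkApplication name and namespace

    # Follow events emitted for the SparkApplication and react to the driver state.
    for event in watch.Watch().stream(
        core_v1.list_namespaced_event,
        namespace=namespace,
        field_selector=f"involvedObject.kind=SparkApplication,involvedObject.name={name}",
    ):
        obj = event["object"]
        if obj.reason == "SparkDriverRunning":
            # Driver pod is up: stream its logs line by line.
            for line in watch.Watch().stream(
                core_v1.read_namespaced_pod_log,
                name=f"{name}-driver",
                namespace=namespace,
                timestamps=True,
            ):
                print(line)
        elif obj.reason in ("SparkApplicationSubmissionFailed", "SparkApplicationFailed", "SparkApplicationDeleted"):
            raise RuntimeError(obj.message)
        elif obj.reason == "SparkApplicationCompleted":
            break
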
diff --git a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
index 7ea67aec80..c7f30dad24 100644
--- a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
@@ -16,8 +16,13 @@
# under the License.
from __future__ import annotations
-from unittest.mock import patch
+from datetime import datetime
+from unittest.mock import MagicMock, patch
+import pytest
+from dateutil import tz
+
+from airflow import AirflowException
from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator
@@ -43,13 +48,22 @@ def test_spark_kubernetes_operator(mock_kubernetes_hook):
@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.Watch.stream")
@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes._load_body_to_dict")
@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
-def test_execute(mock_kubernetes_hook, mock_load_body_to_dict, mock_stream):
+def test_execute_with_watch(mock_kubernetes_hook, mock_load_body_to_dict, mock_stream):
mock_load_body_to_dict.return_value = {"metadata": {"name": "spark-app"}}
+
+ mock_kubernetes_hook.return_value.create_custom_object.return_value = {
+ "metadata": {"name": "spark-app", "creationTimestamp": "2022-01-01T00:00:00Z"}
+ }
mock_kubernetes_hook.return_value.get_namespace.return_value = "default"
- mock_stream.side_effect = [[{"type": "ADDED"}], []]
- op = SparkKubernetesOperator(task_id="task_id", application_file="application_file")
- op.execute({})
+ object_mock = MagicMock()
+ object_mock.reason = "SparkDriverRunning"
+ object_mock.last_timestamp = datetime(2022, 1, 1, 23, 59, 59, tzinfo=tz.tzutc())
+ mock_stream.side_effect = [[{"object": object_mock}], []]
+
+ op = SparkKubernetesOperator(task_id="task_id", application_file="application_file", watch=True)
+ operator_output = op.execute({})
+
mock_kubernetes_hook.return_value.create_custom_object.assert_called_once_with(
group="sparkoperator.k8s.io",
version="v1beta2",
@@ -60,22 +74,70 @@ def test_execute(mock_kubernetes_hook, mock_load_body_to_dict, mock_stream):
assert mock_stream.call_count == 2
mock_stream.assert_any_call(
- mock_kubernetes_hook.return_value.core_v1_client.list_namespaced_pod,
+ mock_kubernetes_hook.return_value.core_v1_client.list_namespaced_event,
namespace="default",
- _preload_content=False,
watch=True,
- label_selector="sparkoperator.k8s.io/app-name=spark-app,spark-role=driver",
- field_selector="status.phase=Running",
+ field_selector="involvedObject.kind=SparkApplication,involvedObject.name=spark-app",
)
-
mock_stream.assert_any_call(
mock_kubernetes_hook.return_value.core_v1_client.read_namespaced_pod_log,
name="spark-app-driver",
namespace="default",
- _preload_content=False,
timestamps=True,
)
+ assert operator_output == {"metadata": {"name": "spark-app", "creationTimestamp": "2022-01-01T00:00:00Z"}}
+
+
+@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.SparkKubernetesOperator.on_kill")
+@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.Watch.stream")
+@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes._load_body_to_dict")
+@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
+def test_raise_exception_when_job_fails(
+ mock_kubernetes_hook, mock_load_body_to_dict, mock_stream, mock_on_kill
+):
+ mock_load_body_to_dict.return_value = {"metadata": {"name": "spark-app"}}
+
+ mock_kubernetes_hook.return_value.create_custom_object.return_value = {
+ "metadata": {"name": "spark-app", "creationTimestamp": "2022-01-01T00:00:00Z"}
+ }
+ mock_kubernetes_hook.return_value.get_namespace.return_value = "default"
+
+ object_mock = MagicMock()
+ object_mock.reason = "SparkApplicationFailed"
+ object_mock.message = "spark-app submission failed"
+ object_mock.last_timestamp = datetime(2022, 1, 1, 23, 59, 59, tzinfo=tz.tzutc())
+
+ mock_stream.side_effect = [[{"object": object_mock}], []]
+ op = SparkKubernetesOperator(task_id="task_id", application_file="application_file", watch=True)
+ with pytest.raises(AirflowException, match="spark-app submission failed"):
+ op.execute({})
+
+ mock_on_kill.assert_called_once()
+
+
+@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes._load_body_to_dict")
+@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
+def test_execute_without_watch(mock_kubernetes_hook, mock_load_body_to_dict):
+ mock_load_body_to_dict.return_value = {"metadata": {"name": "spark-app"}}
+
+ mock_kubernetes_hook.return_value.create_custom_object.return_value = {
+ "metadata": {"name": "spark-app", "creationTimestamp": "2022-01-01T00:00:00Z"}
+ }
+ mock_kubernetes_hook.return_value.get_namespace.return_value = "default"
+
+ op = SparkKubernetesOperator(task_id="task_id", application_file="application_file")
+ operator_output = op.execute({})
+
+ mock_kubernetes_hook.return_value.create_custom_object.assert_called_once_with(
+ group="sparkoperator.k8s.io",
+ version="v1beta2",
+ plural="sparkapplications",
+ body={"metadata": {"name": "spark-app"}},
+ namespace="default",
+ )
+ assert operator_output == {"metadata": {"name": "spark-app", "creationTimestamp": "2022-01-01T00:00:00Z"}}
+
@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes._load_body_to_dict")
@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")