You are viewing a plain-text version of this content. (The hyperlink to the canonical version is not preserved in this plain-text copy.)
Posted to commits@airflow.apache.org by hu...@apache.org on 2024/02/09 08:17:40 UTC
(airflow) branch main updated: Fix rendering `SparkKubernetesOperator.template_body` (#37271)
This is an automated email from the ASF dual-hosted git repository.
husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f691adf710 Fix rendering `SparkKubernetesOperator.template_body` (#37271)
f691adf710 is described below
commit f691adf7105b687b6ba2885c8977607065856fd3
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Fri Feb 9 12:17:33 2024 +0400
Fix rendering `SparkKubernetesOperator.template_body` (#37271)
---
.../cncf/kubernetes/operators/spark_kubernetes.py | 6 +++++-
.../cncf/kubernetes/operators/test_spark_kubernetes.py | 16 ++++++++++++++++
2 files changed, 21 insertions(+), 1 deletion(-)
diff --git a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
index 83d6f484b3..0c177510eb 100644
--- a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
@@ -107,7 +107,6 @@ class SparkKubernetesOperator(KubernetesPodOperator):
self.get_logs = get_logs
self.log_events_on_failure = log_events_on_failure
self.success_run_history_limit = success_run_history_limit
- self.template_body = self.manage_template_specs()
def _render_nested_template_fields(
self,
@@ -193,6 +192,11 @@ class SparkKubernetesOperator(KubernetesPodOperator):
def _try_numbers_match(context, pod) -> bool:
return pod.metadata.labels["try_number"] == context["ti"].try_number
+ @property
+ def template_body(self):
+ """Templated body for CustomObjectLauncher."""
+ return self.manage_template_specs()
+
def find_spark_job(self, context):
labels = self.create_labels_for_pod(context, include_try_number=False)
label_selector = self._get_pod_identifying_label_string(labels) + ",spark-role=driver"
diff --git a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
index d900ef78b4..5d16406926 100644
--- a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
@@ -25,6 +25,7 @@ from unittest import mock
from unittest.mock import patch
import pendulum
+import pytest
import yaml
from kubernetes.client import models as k8s
@@ -488,3 +489,18 @@ class TestSparkKubernetesOperator:
assert op.launcher.body["spec"]["driver"]["tolerations"] == [toleration]
assert op.launcher.body["spec"]["executor"]["tolerations"] == [toleration]
+
+
+@pytest.mark.db_test
+def test_template_body_templating(create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ SparkKubernetesOperator,
+ template_spec={"foo": "{{ ds }}", "bar": "{{ dag_run.dag_id }}"},
+ kubernetes_conn_id="kubernetes_default_kube_config",
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: SparkKubernetesOperator = ti.task
+ assert task.template_body == {"spark": {"foo": "2024-02-01", "bar": "test_template_body_templating_dag"}}