Posted to commits@airflow.apache.org by hu...@apache.org on 2024/02/09 08:17:40 UTC

(airflow) branch main updated: Fix rendering `SparkKubernetesOperator.template_body` (#37271)

This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f691adf710 Fix rendering `SparkKubernetesOperator.template_body` (#37271)
f691adf710 is described below

commit f691adf7105b687b6ba2885c8977607065856fd3
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Fri Feb 9 12:17:33 2024 +0400

    Fix rendering `SparkKubernetesOperator.template_body` (#37271)
---
 .../cncf/kubernetes/operators/spark_kubernetes.py        |  6 +++++-
 .../cncf/kubernetes/operators/test_spark_kubernetes.py   | 16 ++++++++++++++++
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
index 83d6f484b3..0c177510eb 100644
--- a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
@@ -107,7 +107,6 @@ class SparkKubernetesOperator(KubernetesPodOperator):
         self.get_logs = get_logs
         self.log_events_on_failure = log_events_on_failure
         self.success_run_history_limit = success_run_history_limit
-        self.template_body = self.manage_template_specs()
 
     def _render_nested_template_fields(
         self,
@@ -193,6 +192,11 @@ class SparkKubernetesOperator(KubernetesPodOperator):
     def _try_numbers_match(context, pod) -> bool:
         return pod.metadata.labels["try_number"] == context["ti"].try_number
 
+    @property
+    def template_body(self):
+        """Templated body for CustomObjectLauncher."""
+        return self.manage_template_specs()
+
     def find_spark_job(self, context):
         labels = self.create_labels_for_pod(context, include_try_number=False)
         label_selector = self._get_pod_identifying_label_string(labels) + ",spark-role=driver"
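
The hunk above is the whole fix: `template_body` is no longer snapshotted in
`__init__`, before Jinja has rendered the operator's templated fields, but
recomputed by `manage_template_specs()` each time the read-only property is
accessed, so it reflects `template_spec` after rendering. A minimal sketch of
that timing difference, using made-up stand-in classes rather than the real
Airflow operator:

    # Self-contained sketch (not Airflow code): "EagerOperator" and
    # "LazyOperator" are invented names showing why a value computed once in
    # __init__ misses later template rendering, while a property picks it up.

    class EagerOperator:
        def __init__(self, template_spec):
            self.template_spec = template_spec
            # Snapshot taken before any templating has happened.
            self.template_body = {"spark": self.template_spec}

    class LazyOperator:
        def __init__(self, template_spec):
            self.template_spec = template_spec

        @property
        def template_body(self):
            # Recomputed on every access, so it sees the rendered field.
            return {"spark": self.template_spec}

    def render(op):
        # Stand-in for Airflow rendering the operator's templated fields.
        op.template_spec = {"foo": "2024-02-01"}

    eager = EagerOperator({"foo": "{{ ds }}"})
    lazy = LazyOperator({"foo": "{{ ds }}"})
    render(eager)
    render(lazy)

    assert eager.template_body == {"spark": {"foo": "{{ ds }}"}}   # stale snapshot
    assert lazy.template_body == {"spark": {"foo": "2024-02-01"}}  # rendered value
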
diff --git a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
index d900ef78b4..5d16406926 100644
--- a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
@@ -25,6 +25,7 @@ from unittest import mock
 from unittest.mock import patch
 
 import pendulum
+import pytest
 import yaml
 from kubernetes.client import models as k8s
 
@@ -488,3 +489,18 @@ class TestSparkKubernetesOperator:
 
         assert op.launcher.body["spec"]["driver"]["tolerations"] == [toleration]
         assert op.launcher.body["spec"]["executor"]["tolerations"] == [toleration]
+
+
+@pytest.mark.db_test
+def test_template_body_templating(create_task_instance_of_operator):
+    ti = create_task_instance_of_operator(
+        SparkKubernetesOperator,
+        template_spec={"foo": "{{ ds }}", "bar": "{{ dag_run.dag_id }}"},
+        kubernetes_conn_id="kubernetes_default_kube_config",
+        dag_id="test_template_body_templating_dag",
+        task_id="test_template_body_templating_task",
+        execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+    )
+    ti.render_templates()
+    task: SparkKubernetesOperator = ti.task
+    assert task.template_body == {"spark": {"foo": "2024-02-01", "bar": "test_template_body_templating_dag"}}
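
The new db test exercises the same behaviour end to end: `ti.render_templates()`
resolves the Jinja expressions in `template_spec`, and the property is then
expected to return the rendered spec nested under a top-level "spark" key, as
the assertion shows. One related detail, sketched below with a toy class rather
than the real operator: because the property defines no setter, assigning to
`template_body` (as the deleted `__init__` line did) now raises
`AttributeError`, so the old assignment cannot be kept alongside the property.

    class Toy:
        @property
        def template_body(self):
            # Computed on access; no setter is defined.
            return {"spark": {}}

    toy = Toy()
    print(toy.template_body)    # {'spark': {}} -- reads work as before
    try:
        toy.template_body = {}  # what the removed __init__ assignment would now do
    except AttributeError as exc:
        print(f"read-only property: {exc}")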