Posted to commits@airflow.apache.org by ka...@apache.org on 2021/03/20 19:43:55 UTC

[airflow] branch master updated: Add ability to specify api group and version for Spark operators (#14898)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 00453dc  Add ability to specify api group and version for Spark operators (#14898)
00453dc is described below

commit 00453dc4a2d41da6c46e73cd66cac88e7556de71
Author: Arkadiy Krava <ar...@ex.ua>
AuthorDate: Sat Mar 20 21:43:37 2021 +0200

    Add ability to specify api group and version for Spark operators (#14898)
    
    closes: #14897
    Two parameters were added to SparkKubernetesOperator and SparkKubernetesSensor. I've placed them at the end of the signatures and provided default values, so backward compatibility is preserved.
    Parameter descriptions and tests were also added.
---
 .../cncf/kubernetes/operators/spark_kubernetes.py  | 12 +++++++++--
 .../cncf/kubernetes/sensors/spark_kubernetes.py    | 12 +++++++++--
 .../kubernetes/operators/test_spark_kubernetes.py  | 24 +++++++++++++++++++++
 .../kubernetes/sensors/test_spark_kubernetes.py    | 25 ++++++++++++++++++++++
 4 files changed, 69 insertions(+), 4 deletions(-)
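
For readers who want to try the new parameters, here is a minimal DAG sketch. The DAG id, schedule, file name, and the custom group/version values are illustrative assumptions, not taken from this commit; both parameters default to 'sparkoperator.k8s.io' / 'v1beta2', so existing DAGs keep working unchanged.

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator
    from airflow.providers.cncf.kubernetes.sensors.spark_kubernetes import SparkKubernetesSensor

    with DAG(dag_id="spark_pi", start_date=datetime(2021, 3, 1), schedule_interval=None) as dag:
        # Submit the SparkApplication through a hypothetical custom build of the
        # Spark operator that serves a different API group and version.
        submit = SparkKubernetesOperator(
            task_id="submit",
            application_file="spark_pi.yaml",        # assumption: file next to the DAG
            namespace="default",
            api_group="sparkoperator.example.com",   # assumption: custom operator build
            api_version="v1alpha1",                  # assumption: custom CRD version
        )
        # Poll the same object until it reaches a terminal state; the sensor must
        # use the same group/version the application was created under.
        monitor = SparkKubernetesSensor(
            task_id="monitor",
            application_name="spark-pi",
            namespace="default",
            api_group="sparkoperator.example.com",
            api_version="v1alpha1",
        )
        submit >> monitor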

diff --git a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
index 1589d59..db2e589 100644
--- a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
@@ -36,6 +36,10 @@ class SparkKubernetesOperator(BaseOperator):
     :type namespace: str
     :param kubernetes_conn_id: the connection to Kubernetes cluster
     :type kubernetes_conn_id: str
+    :param api_group: kubernetes api group of the sparkApplication
+    :type api_group: str
+    :param api_version: kubernetes api version of the sparkApplication
+    :type api_version: str
     """
 
     template_fields = ['application_file', 'namespace']
@@ -49,19 +53,23 @@ class SparkKubernetesOperator(BaseOperator):
         application_file: str,
         namespace: Optional[str] = None,
         kubernetes_conn_id: str = 'kubernetes_default',
+        api_group: str = 'sparkoperator.k8s.io',
+        api_version: str = 'v1beta2',
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         self.application_file = application_file
         self.namespace = namespace
         self.kubernetes_conn_id = kubernetes_conn_id
+        self.api_group = api_group
+        self.api_version = api_version
 
     def execute(self, context):
         self.log.info("Creating sparkApplication")
         hook = KubernetesHook(conn_id=self.kubernetes_conn_id)
         response = hook.create_custom_object(
-            group="sparkoperator.k8s.io",
-            version="v1beta2",
+            group=self.api_group,
+            version=self.api_version,
             plural="sparkapplications",
             body=self.application_file,
             namespace=self.namespace,
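
One caveat worth noting: api_group and api_version only select the API endpoint that create_custom_object talks to; the apiVersion field inside the application manifest still has to agree with them, since the API server validates the body against the requested group and version. A hedged sketch of a matching body follows (the group, version, and spec are placeholders, not from this commit):

    # Assumed custom coordinates; they must match the operator's
    # api_group / api_version arguments.
    api_group = "sparkoperator.example.com"
    api_version = "v1alpha1"

    application_body = {
        "apiVersion": f"{api_group}/{api_version}",  # must agree with the request coordinates
        "kind": "SparkApplication",
        "metadata": {"name": "spark-pi", "namespace": "default"},
        "spec": {
            # ... driver/executor spec as in the stock spark-pi example ...
        },
    }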
diff --git a/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py
index eb555f1..feb5922 100644
--- a/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py
@@ -41,6 +41,10 @@ class SparkKubernetesSensor(BaseSensorOperator):
     :type kubernetes_conn_id: str
     :param attach_log: determines whether logs for driver pod should be appended to the sensor log
     :type attach_log: bool
+    :param api_group: kubernetes api group of the sparkApplication
+    :type api_group: str
+    :param api_version: kubernetes api version of the sparkApplication
+    :type api_version: str
     """
 
     template_fields = ("application_name", "namespace")
@@ -55,6 +59,8 @@ class SparkKubernetesSensor(BaseSensorOperator):
         attach_log: bool = False,
         namespace: Optional[str] = None,
         kubernetes_conn_id: str = "kubernetes_default",
+        api_group: str = 'sparkoperator.k8s.io',
+        api_version: str = 'v1beta2',
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -63,6 +69,8 @@ class SparkKubernetesSensor(BaseSensorOperator):
         self.namespace = namespace
         self.kubernetes_conn_id = kubernetes_conn_id
         self.hook = KubernetesHook(conn_id=self.kubernetes_conn_id)
+        self.api_group = api_group
+        self.api_version = api_version
 
     def _log_driver(self, application_state: str, response: dict) -> None:
         if not self.attach_log:
@@ -93,8 +101,8 @@ class SparkKubernetesSensor(BaseSensorOperator):
     def poke(self, context: Dict) -> bool:
         self.log.info("Poking: %s", self.application_name)
         response = self.hook.get_custom_object(
-            group="sparkoperator.k8s.io",
-            version="v1beta2",
+            group=self.api_group,
+            version=self.api_version,
             plural="sparkapplications",
             name=self.application_name,
             namespace=self.namespace,
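
With this change applied, each poke issues the equivalent of the hook call below. The status lookup mirrors the sensor's existing handling of SparkApplication objects; the connection id, names, and custom group/version are illustrative:

    from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook

    hook = KubernetesHook(conn_id="kubernetes_default")
    response = hook.get_custom_object(
        group="sparkoperator.example.com",  # self.api_group
        version="v1alpha1",                 # self.api_version
        plural="sparkapplications",
        name="spark-pi",                    # self.application_name
        namespace="default",
    )
    # SparkApplication objects report their lifecycle under
    # status.applicationState.state, e.g. SUBMITTED, RUNNING, COMPLETED, FAILED.
    state = response["status"]["applicationState"]["state"]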
diff --git a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
index 49e3ec0..ed99a17 100644
--- a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
@@ -213,6 +213,30 @@ class TestSparkKubernetesOperator(unittest.TestCase):
         )
 
     @patch('kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object')
+    def test_create_application_from_json_with_api_group_and_version(
+        self, mock_create_namespaced_crd, mock_kubernetes_hook
+    ):
+        api_group = 'sparkoperator.example.com'
+        api_version = 'v1alpha1'
+        op = SparkKubernetesOperator(
+            application_file=TEST_VALID_APPLICATION_JSON,
+            dag=self.dag,
+            kubernetes_conn_id='kubernetes_default_kube_config',
+            task_id='test_task_id',
+            api_group=api_group,
+            api_version=api_version,
+        )
+        op.execute(None)
+        mock_kubernetes_hook.assert_called_once_with()
+        mock_create_namespaced_crd.assert_called_with(
+            body=TEST_APPLICATION_DICT,
+            group=api_group,
+            namespace='default',
+            plural='sparkapplications',
+            version=api_version,
+        )
+
+    @patch('kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object')
     def test_namespace_from_operator(self, mock_create_namespaced_crd, mock_kubernetes_hook):
         op = SparkKubernetesOperator(
             application_file=TEST_VALID_APPLICATION_JSON,
diff --git a/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py b/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py
index 2bfb50e..b4e5986 100644
--- a/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py
@@ -664,6 +664,31 @@ class TestSparkKubernetesSensor(unittest.TestCase):
         "kubernetes.client.api.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object",
         return_value=TEST_COMPLETED_APPLICATION,
     )
+    def test_api_group_and_version_from_sensor(self, mock_get_namespaced_crd, mock_kubernetes_hook):
+        api_group = 'sparkoperator.example.com'
+        api_version = 'v1alpha1'
+        sensor = SparkKubernetesSensor(
+            application_name="spark_pi",
+            dag=self.dag,
+            kubernetes_conn_id="kubernetes_with_namespace",
+            task_id="test_task_id",
+            api_group=api_group,
+            api_version=api_version,
+        )
+        sensor.poke(None)
+        mock_kubernetes_hook.assert_called_once_with()
+        mock_get_namespaced_crd.assert_called_once_with(
+            group=api_group,
+            name="spark_pi",
+            namespace="mock_namespace",
+            plural="sparkapplications",
+            version=api_version,
+        )
+
+    @patch(
+        "kubernetes.client.api.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object",
+        return_value=TEST_COMPLETED_APPLICATION,
+    )
     def test_namespace_from_connection(self, mock_get_namespaced_crd, mock_kubernetes_hook):
         sensor = SparkKubernetesSensor(
             application_name="spark_pi",