You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by fo...@apache.org on 2018/03/12 11:12:26 UTC

incubator-airflow git commit: [AIRFLOW-2140] Add Kubernetes scheduler to SparkSubmitOperator

Repository: incubator-airflow
Updated Branches:
  refs/heads/master c6bdaf829 -> 64100d2a2


[AIRFLOW-2140] Add Kubernetes scheduler to SparkSubmitOperator

Closes #3112 from RJKeevil/spark-k8s


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/64100d2a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/64100d2a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/64100d2a

Branch: refs/heads/master
Commit: 64100d2a2e0b9dafbd2f0b355d414781d98f41c9
Parents: c6bdaf8
Author: Rob Keevil <ro...@gmail.com>
Authored: Mon Mar 12 12:12:14 2018 +0100
Committer: Fokko Driesprong <fo...@godatadriven.com>
Committed: Mon Mar 12 12:12:14 2018 +0100

----------------------------------------------------------------------
 airflow/contrib/hooks/spark_submit_hook.py    | 57 +++++++++++--
 tests/contrib/hooks/test_spark_submit_hook.py | 93 +++++++++++++++++++---
 2 files changed, 136 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/64100d2a/airflow/contrib/hooks/spark_submit_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/spark_submit_hook.py b/airflow/contrib/hooks/spark_submit_hook.py
index a7a083a..ae024a9 100644
--- a/airflow/contrib/hooks/spark_submit_hook.py
+++ b/airflow/contrib/hooks/spark_submit_hook.py
@@ -20,6 +20,8 @@ import time
 from airflow.hooks.base_hook import BaseHook
 from airflow.exceptions import AirflowException
 from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.contrib.kubernetes import kube_client
+from kubernetes.client.rest import ApiException
 
 
 class SparkSubmitHook(BaseHook, LoggingMixin):
@@ -56,8 +58,8 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
     :param total_executor_cores: (Standalone & Mesos only) Total cores for all executors
     (Default: all the available cores on the worker)
     :type total_executor_cores: int
-    :param executor_cores: (Standalone & YARN only) Number of cores per executor
-    (Default: 2)
+    :param executor_cores: (Standalone, YARN and Kubernetes only) Number of cores per
+    executor (Default: 2)
     :type executor_cores: int
     :param executor_memory: Memory per executor (e.g. 1000M, 2G) (Default: 1G)
     :type executor_memory: str
@@ -119,13 +121,16 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         self._verbose = verbose
         self._submit_sp = None
         self._yarn_application_id = None
+        self._kubernetes_driver_pod = None
 
         self._connection = self._resolve_connection()
         self._is_yarn = 'yarn' in self._connection['master']
+        self._is_kubernetes = 'k8s' in self._connection['master']
 
         self._should_track_driver_status = self._resolve_should_track_driver_status()
         self._driver_id = None
         self._driver_status = None
+        self._spark_exit_code = None
 
     def _resolve_should_track_driver_status(self):
         """
@@ -142,10 +147,12 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                      'queue': None,
                      'deploy_mode': None,
                      'spark_home': None,
-                     'spark_binary': 'spark-submit'}
+                     'spark_binary': 'spark-submit',
+                     'namespace': 'default'}
 
         try:
-            # Master can be local, yarn, spark://HOST:PORT or mesos://HOST:PORT
+            # Master can be local, yarn, spark://HOST:PORT, mesos://HOST:PORT and
+            # k8s://https://<HOST>:<PORT>
             conn = self.get_connection(self._conn_id)
             if conn.port:
                 conn_data['master'] = "{}:{}".format(conn.host, conn.port)
@@ -158,6 +165,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
             conn_data['deploy_mode'] = extra.get('deploy-mode', None)
             conn_data['spark_home'] = extra.get('spark-home', None)
             conn_data['spark_binary'] = extra.get('spark-binary', 'spark-submit')
+            conn_data['namespace'] = extra.get('namespace', 'default')
         except AirflowException:
             self.log.debug(
                 "Could not load connection string %s, defaulting to %s",
@@ -196,6 +204,9 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         if self._conf:
             for key in self._conf:
                 connection_cmd += ["--conf", "{}={}".format(key, str(self._conf[key]))]
+        if self._is_kubernetes:
+            connection_cmd += ["--conf", "spark.kubernetes.namespace={}".format(
+                self._connection['namespace'])]
         if self._files:
             connection_cmd += ["--files", self._files]
         if self._py_files:
@@ -288,7 +299,9 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         self._process_spark_submit_log(iter(self._submit_sp.stdout.readline, ''))
         returncode = self._submit_sp.wait()
 
-        if returncode:
+        # Check spark-submit return code. In Kubernetes mode, also check the value
+        # of exit code in the log, as it may differ.
+        if returncode or (self._is_kubernetes and self._spark_exit_code != 0):
             raise AirflowException(
                 "Cannot execute: {}. Error code is: {}.".format(
                     spark_submit_cmd, returncode
@@ -335,6 +348,22 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                 match = re.search('(application[0-9_]+)', line)
                 if match:
                     self._yarn_application_id = match.groups()[0]
+                    self.log.info("Identified spark driver id: %s",
+                                  self._yarn_application_id)
+
+            # If we run Kubernetes cluster mode, we want to extract the driver pod id
+            # from the logs so we can kill the application when we stop it unexpectedly
+            if self._is_kubernetes:
+                match = re.search('\s*pod name: ((.+?)-([a-z0-9]+)-driver)', line)
+                if match:
+                    self._kubernetes_driver_pod = match.groups()[0]
+                    self.log.info("Identified spark driver pod: %s",
+                                  self._kubernetes_driver_pod)
+
+                # Store the Spark Exit code
+                match_exit_code = re.search('\s*Exit code: (\d+)', line)
+                if match_exit_code:
+                    self._spark_exit_code = int(match_exit_code.groups()[0])
 
             # if we run in standalone cluster mode and we want to track the driver status
             # we need to extract the driver id from the logs. This allows us to poll for
@@ -468,3 +497,21 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                                              stderr=subprocess.PIPE)
 
                 self.log.info("YARN killed with return code: %s", yarn_kill.wait())
+
+            if self._kubernetes_driver_pod:
+                self.log.info('Killing pod %s on Kubernetes', self._kubernetes_driver_pod)
+
+                # Currently only instantiate Kubernetes client for killing a spark pod.
+                try:
+                    client = kube_client.get_kube_client()
+                    api_response = client.delete_namespaced_pod(
+                        self._kubernetes_driver_pod,
+                        self._connection['namespace'],
+                        body=client.V1DeleteOptions(),
+                        pretty=True)
+
+                    self.log.info("Spark on K8s killed with response: %s", api_response)
+
+                except ApiException as e:
+                    self.log.info("Exception when attempting to kill Spark on K8s:")
+                    self.log.exception(e)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/64100d2a/tests/contrib/hooks/test_spark_submit_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_spark_submit_hook.py b/tests/contrib/hooks/test_spark_submit_hook.py
index 7821d02..c167797 100644
--- a/tests/contrib/hooks/test_spark_submit_hook.py
+++ b/tests/contrib/hooks/test_spark_submit_hook.py
@@ -74,6 +74,14 @@ class TestSparkSubmitHook(unittest.TestCase):
         )
         db.merge_conn(
             models.Connection(
+                conn_id='spark_k8s_cluster', conn_type='spark',
+                host='k8s://https://k8s-master',
+                extra='{"spark-home": "/opt/spark", ' +
+                      '"deploy-mode": "cluster", ' +
+                      '"namespace": "mynamespace"}')
+        )
+        db.merge_conn(
+            models.Connection(
                 conn_id='spark_default_mesos', conn_type='spark',
                 host='mesos://host', port=5050)
         )
@@ -166,6 +174,7 @@ class TestSparkSubmitHook(unittest.TestCase):
         # Given
         hook_default = SparkSubmitHook(conn_id='')
         hook_spark_yarn_cluster = SparkSubmitHook(conn_id='spark_yarn_cluster')
+        hook_spark_k8s_cluster = SparkSubmitHook(conn_id='spark_k8s_cluster')
         hook_spark_default_mesos = SparkSubmitHook(conn_id='spark_default_mesos')
         hook_spark_home_set = SparkSubmitHook(conn_id='spark_home_set')
         hook_spark_home_not_set = SparkSubmitHook(conn_id='spark_home_not_set')
@@ -180,6 +189,8 @@ class TestSparkSubmitHook(unittest.TestCase):
             ._resolve_should_track_driver_status()
         should_track_driver_status_spark_yarn_cluster = hook_spark_yarn_cluster \
             ._resolve_should_track_driver_status()
+        should_track_driver_status_spark_k8s_cluster = hook_spark_k8s_cluster \
+            ._resolve_should_track_driver_status()
         should_track_driver_status_spark_default_mesos = hook_spark_default_mesos \
             ._resolve_should_track_driver_status()
         should_track_driver_status_spark_home_set = hook_spark_home_set \
@@ -196,6 +207,7 @@ class TestSparkSubmitHook(unittest.TestCase):
         # Then
         self.assertEqual(should_track_driver_status_default, False)
         self.assertEqual(should_track_driver_status_spark_yarn_cluster, False)
+        self.assertEqual(should_track_driver_status_spark_k8s_cluster, False)
         self.assertEqual(should_track_driver_status_spark_default_mesos, False)
         self.assertEqual(should_track_driver_status_spark_home_set, False)
         self.assertEqual(should_track_driver_status_spark_home_not_set, False)
@@ -217,7 +229,8 @@ class TestSparkSubmitHook(unittest.TestCase):
                                      "spark_binary": "spark-submit",
                                      "deploy_mode": None,
                                      "queue": None,
-                                     "spark_home": None}
+                                     "spark_home": None,
+                                     "namespace": 'default'}
         self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(dict_cmd["--master"], "yarn")
 
@@ -235,7 +248,8 @@ class TestSparkSubmitHook(unittest.TestCase):
                                      "spark_binary": "spark-submit",
                                      "deploy_mode": None,
                                      "queue": "root.default",
-                                     "spark_home": None}
+                                     "spark_home": None,
+                                     "namespace": 'default'}
         self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(dict_cmd["--master"], "yarn")
         self.assertEqual(dict_cmd["--queue"], "root.default")
@@ -254,7 +268,8 @@ class TestSparkSubmitHook(unittest.TestCase):
                                      "spark_binary": "spark-submit",
                                      "deploy_mode": None,
                                      "queue": None,
-                                     "spark_home": None}
+                                     "spark_home": None,
+                                     "namespace": 'default'}
         self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(dict_cmd["--master"], "mesos://host:5050")
 
@@ -272,12 +287,33 @@ class TestSparkSubmitHook(unittest.TestCase):
                                      "spark_binary": "spark-submit",
                                      "deploy_mode": "cluster",
                                      "queue": "root.etl",
-                                     "spark_home": None}
+                                     "spark_home": None,
+                                     "namespace": 'default'}
         self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(dict_cmd["--master"], "yarn://yarn-master")
         self.assertEqual(dict_cmd["--queue"], "root.etl")
         self.assertEqual(dict_cmd["--deploy-mode"], "cluster")
 
+    def test_resolve_connection_spark_k8s_cluster_connection(self):
+        # Given
+        hook = SparkSubmitHook(conn_id='spark_k8s_cluster')
+
+        # When
+        connection = hook._resolve_connection()
+        cmd = hook._build_spark_submit_command(self._spark_job_file)
+
+        # Then: namespace extra is resolved and passed as a --conf
+        dict_cmd = self.cmd_args_to_dict(cmd)
+        expected_spark_connection = {"spark_home": "/opt/spark", "queue": None,
+                                     "spark_binary": "spark-submit",
+                                     "master": "k8s://https://k8s-master",
+                                     "deploy_mode": "cluster",
+                                     "namespace": "mynamespace"}
+        self.assertEqual(connection, expected_spark_connection)
+        self.assertEqual(dict_cmd["--master"], "k8s://https://k8s-master")
+        self.assertEqual(dict_cmd["--deploy-mode"], "cluster")
+        self.assertIn("spark.kubernetes.namespace=mynamespace", cmd)
+
     def test_resolve_connection_spark_home_set_connection(self):
         # Given
         hook = SparkSubmitHook(conn_id='spark_home_set')
@@ -291,7 +327,8 @@ class TestSparkSubmitHook(unittest.TestCase):
                                      "spark_binary": "spark-submit",
                                      "deploy_mode": None,
                                      "queue": None,
-                                     "spark_home": "/opt/myspark"}
+                                     "spark_home": "/opt/myspark",
+                                     "namespace": 'default'}
         self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(cmd[0], '/opt/myspark/bin/spark-submit')
 
@@ -308,7 +345,8 @@ class TestSparkSubmitHook(unittest.TestCase):
                                      "spark_binary": "spark-submit",
                                      "deploy_mode": None,
                                      "queue": None,
-                                     "spark_home": None}
+                                     "spark_home": None,
+                                     "namespace": 'default'}
         self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(cmd[0], 'spark-submit')
 
@@ -325,7 +363,8 @@ class TestSparkSubmitHook(unittest.TestCase):
                                      "spark_binary": "custom-spark-submit",
                                      "deploy_mode": None,
                                      "queue": None,
-                                     "spark_home": None}
+                                     "spark_home": None,
+                                     "namespace": 'default'}
         self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(cmd[0], 'custom-spark-submit')
 
@@ -342,7 +381,8 @@ class TestSparkSubmitHook(unittest.TestCase):
                                      "spark_binary": "custom-spark-submit",
                                      "deploy_mode": None,
                                      "queue": None,
-                                     "spark_home": "/path/to/spark_home"}
+                                     "spark_home": "/path/to/spark_home",
+                                     "namespace": 'default'}
         self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(cmd[0], '/path/to/spark_home/bin/custom-spark-submit')
 
@@ -359,7 +399,8 @@ class TestSparkSubmitHook(unittest.TestCase):
                                      "spark_binary": "spark-submit",
                                      "deploy_mode": "cluster",
                                      "queue": None,
-                                     "spark_home": "/path/to/spark_home"}
+                                     "spark_home": "/path/to/spark_home",
+                                     "namespace": 'default'}
         self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(cmd[0], '/path/to/spark_home/bin/spark-submit')
 
@@ -383,6 +424,40 @@ class TestSparkSubmitHook(unittest.TestCase):
 
         self.assertEqual(hook._yarn_application_id, 'application_1486558679801_1820')
 
+    def test_process_spark_submit_log_k8s(self):
+        # Given: each list element is one stdout line emitted by spark-submit
+        hook = SparkSubmitHook(conn_id='spark_k8s_cluster')
+        log_lines = [
+            'INFO  LoggingPodStatusWatcherImpl:54 - State changed, new state:',
+            'pod name: spark-pi-edf2ace37be7353a958b38733a12f8e6-driver',
+            'namespace: default',
+            'labels: spark-app-selector -> spark-465b868ada474bda82ccb84ab2747fcd, ' +
+            'spark-role -> driver',
+            'pod uid: ba9c61f6-205f-11e8-b65f-d48564c88e42',
+            'creation time: 2018-03-05T10:26:55Z',
+            'service account name: spark',
+            'volumes: spark-init-properties, download-jars-volume, ' +
+            'download-files-volume, spark-token-2vmlm',
+            'node name: N/A',
+            'start time: N/A',
+            'container images: N/A',
+            'phase: Pending',
+            'status: []',
+            '2018-03-05 11:26:56 INFO  LoggingPodStatusWatcherImpl:54 - State changed,' +
+            ' new state:',
+            'pod name: spark-pi-edf2ace37be7353a958b38733a12f8e6-driver',
+            'namespace: default',
+            'Exit code: 999'
+        ]
+
+        # When
+        hook._process_spark_submit_log(log_lines)
+
+        # Then
+        self.assertEqual(hook._kubernetes_driver_pod,
+                         'spark-pi-edf2ace37be7353a958b38733a12f8e6-driver')
+        self.assertEqual(hook._spark_exit_code, 999)
+
     def test_process_spark_submit_log_standalone_cluster(self):
         # Given
         hook = SparkSubmitHook(conn_id='spark_standalone_cluster')