Posted to commits@airflow.apache.org by bo...@apache.org on 2017/04/07 17:21:31 UTC

incubator-airflow git commit: [AIRFLOW-1085] Enhance the SparkSubmitOperator

Repository: incubator-airflow
Updated Branches:
  refs/heads/master 35e43f506 -> 0ade066f4


[AIRFLOW-1085] Enhance the SparkSubmitOperator

- Allow the Spark home to be set on a per-connection basis. This obviates
  the need for spark-submit to be on the PATH and allows different
  versions of Spark to be used easily.
- Enable the use of the --driver-memory parameter on spark-submit by
  making it a parameter on the operator.
- Enable the use of the --class parameter on spark-submit by making it
  a parameter on the operator.
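
For illustration, a minimal usage sketch of the enhanced operator. The DAG,
task_id, conn_id and jar path below are placeholders, not part of this change:

    from datetime import datetime
    from airflow import DAG
    from airflow.contrib.operators.spark_submit_operator import SparkSubmitOperator

    # Illustrative DAG and values only.
    dag = DAG('spark_submit_example', start_date=datetime(2017, 1, 1),
              schedule_interval=None)

    submit_app = SparkSubmitOperator(
        task_id='submit_spark_app',
        conn_id='spark_default',            # connection extra may set {"spark-home": "/opt/myspark"}
        application='/path/to/app.jar',
        java_class='com.foo.bar.AppMain',   # passed through as --class
        driver_memory='2g',                 # passed through as --driver-memory
        dag=dag,
    )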

Closes #2211 from camshrun/sparkSubmitImprovements


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/0ade066f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/0ade066f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/0ade066f

Branch: refs/heads/master
Commit: 0ade066f44257c5e119b292f4cc2ba105774f4e7
Parents: 35e43f5
Author: Stephan Werges <sw...@accertify.com>
Authored: Fri Apr 7 19:20:46 2017 +0200
Committer: Bolke de Bruin <bo...@xs4all.nl>
Committed: Fri Apr 7 19:20:58 2017 +0200

----------------------------------------------------------------------
 airflow/contrib/hooks/spark_submit_hook.py      | 32 ++++++++++--
 .../contrib/operators/spark_submit_operator.py  | 13 ++++-
 tests/contrib/hooks/spark_submit_hook.py        | 51 +++++++++++++++++---
 .../contrib/operators/spark_submit_operator.py  |  8 ++-
 4 files changed, 90 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0ade066f/airflow/contrib/hooks/spark_submit_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/spark_submit_hook.py b/airflow/contrib/hooks/spark_submit_hook.py
index 619cc71..59d28b5 100644
--- a/airflow/contrib/hooks/spark_submit_hook.py
+++ b/airflow/contrib/hooks/spark_submit_hook.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 #
 import logging
+import os
 import subprocess
 import re
 
@@ -25,7 +26,8 @@ log = logging.getLogger(__name__)
 class SparkSubmitHook(BaseHook):
     """
     This hook is a wrapper around the spark-submit binary to kick off a spark-submit job.
-    It requires that the "spark-submit" binary is in the PATH.
+    It requires that the "spark-submit" binary is in the PATH or that the
+    spark_home is supplied.
     :param conf: Arbitrary Spark configuration properties
     :type conf: dict
     :param conn_id: The connection id as configured in Airflow administration. When an
@@ -38,10 +40,14 @@ class SparkSubmitHook(BaseHook):
     :type py_files: str
     :param jars: Submit additional jars to upload and place them in executor classpath.
     :type jars: str
+    :param java_class: the main class of the Java application
+    :type java_class: str
     :param executor_cores: Number of cores per executor (Default: 2)
     :type executor_cores: int
     :param executor_memory: Memory per executor (e.g. 1000M, 2G) (Default: 1G)
     :type executor_memory: str
+    :param driver_memory: Memory allocated to the driver (e.g. 1000M, 2G) (Default: 1G)
+    :type driver_memory: str
     :param keytab: Full path to the file that contains the keytab
     :type keytab: str
     :param principal: The name of the kerberos principal used for keytab
@@ -60,8 +66,10 @@ class SparkSubmitHook(BaseHook):
                  files=None,
                  py_files=None,
                  jars=None,
+                 java_class=None,
                  executor_cores=None,
                  executor_memory=None,
+                 driver_memory=None,
                  keytab=None,
                  principal=None,
                  name='default-name',
@@ -72,8 +80,10 @@ class SparkSubmitHook(BaseHook):
         self._files = files
         self._py_files = py_files
         self._jars = jars
+        self._java_class = java_class
         self._executor_cores = executor_cores
         self._executor_memory = executor_memory
+        self._driver_memory = driver_memory
         self._keytab = keytab
         self._principal = principal
         self._name = name
@@ -82,7 +92,7 @@ class SparkSubmitHook(BaseHook):
         self._sp = None
         self._yarn_application_id = None
 
-        (self._master, self._queue, self._deploy_mode) = self._resolve_connection()
+        (self._master, self._queue, self._deploy_mode, self._spark_home) = self._resolve_connection()
         self._is_yarn = 'yarn' in self._master
 
     def _resolve_connection(self):
@@ -90,6 +100,7 @@ class SparkSubmitHook(BaseHook):
         master = 'yarn'
         queue = None
         deploy_mode = None
+        spark_home = None
 
         try:
             # Master can be local, yarn, spark://HOST:PORT or mesos://HOST:PORT
@@ -105,6 +116,8 @@ class SparkSubmitHook(BaseHook):
                 queue = extra['queue']
             if 'deploy-mode' in extra:
                 deploy_mode = extra['deploy-mode']
+            if 'spark-home' in extra:
+                spark_home = extra['spark-home']
         except AirflowException:
             logging.debug(
                 "Could not load connection string {}, defaulting to {}".format(
@@ -112,7 +125,7 @@ class SparkSubmitHook(BaseHook):
                 )
             )
 
-        return master, queue, deploy_mode
+        return master, queue, deploy_mode, spark_home
 
     def get_conn(self):
         pass
@@ -124,8 +137,13 @@ class SparkSubmitHook(BaseHook):
         :type application: str
         :return: full command to be executed
         """
-        # The spark-submit binary needs to be in the path
-        connection_cmd = ["spark-submit"]
+        # If spark_home is set, build the path to the spark-submit executable from
+        # it; otherwise assume that spark-submit is on the PATH of the executing
+        # user.
+        if self._spark_home:
+            connection_cmd = [os.path.join(self._spark_home, 'bin', 'spark-submit')]
+        else:
+            connection_cmd = ['spark-submit']
 
         # The URL of the Spark master
         connection_cmd += ["--master", self._master]
@@ -145,12 +163,16 @@ class SparkSubmitHook(BaseHook):
             connection_cmd += ["--executor-cores", str(self._executor_cores)]
         if self._executor_memory:
             connection_cmd += ["--executor-memory", self._executor_memory]
+        if self._driver_memory:
+            connection_cmd += ["--driver-memory", self._driver_memory]
         if self._keytab:
             connection_cmd += ["--keytab", self._keytab]
         if self._principal:
             connection_cmd += ["--principal", self._principal]
         if self._name:
             connection_cmd += ["--name", self._name]
+        if self._java_class:
+            connection_cmd += ["--class", self._java_class]
         if self._verbose:
             connection_cmd += ["--verbose"]
         if self._queue:

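As context for the new "spark-home" handling above, a connection carrying the
extra field can be created as sketched here, mirroring the test setup in this
commit. The conn_id, host and path are illustrative:

    from airflow import models
    from airflow.utils import db

    # Illustrative connection; "spark-home" in the extra JSON points the hook at
    # /opt/myspark/bin/spark-submit instead of relying on spark-submit being on
    # the PATH.
    db.merge_conn(
        models.Connection(
            conn_id='spark_home_set', conn_type='spark',
            host='yarn://yarn-master',
            extra='{"spark-home": "/opt/myspark"}')
    )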
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0ade066f/airflow/contrib/operators/spark_submit_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/spark_submit_operator.py b/airflow/contrib/operators/spark_submit_operator.py
index a5e6145..f62c395 100644
--- a/airflow/contrib/operators/spark_submit_operator.py
+++ b/airflow/contrib/operators/spark_submit_operator.py
@@ -24,7 +24,8 @@ log = logging.getLogger(__name__)
 class SparkSubmitOperator(BaseOperator):
     """
     This operator is a wrapper around the spark-submit binary to kick off a spark-submit job.
-    It requires that the "spark-submit" binary is in the PATH.
+    It requires that the "spark-submit" binary is in the PATH or the spark-home is set
+    in the extra on the connection.
     :param application: The application submitted as a job, either a jar or a py file.
     :type application: str
     :param conf: Arbitrary Spark configuration properties
@@ -39,10 +40,14 @@ class SparkSubmitOperator(BaseOperator):
     :type py_files: str
     :param jars: Submit additional jars to upload and place them in executor classpath.
     :type jars: str
+    :param java_class: the main class of the Java application
+    :type java_class: str
     :param executor_cores: Number of cores per executor (Default: 2)
     :type executor_cores: int
     :param executor_memory: Memory per executor (e.g. 1000M, 2G) (Default: 1G)
     :type executor_memory: str
+    :param driver_memory: Memory allocated to the driver (e.g. 1000M, 2G) (Default: 1G)
+    :type driver_memory: str
     :param keytab: Full path to the file that contains the keytab
     :type keytab: str
     :param principal: The name of the kerberos principal used for keytab
@@ -63,8 +68,10 @@ class SparkSubmitOperator(BaseOperator):
                  files=None,
                  py_files=None,
                  jars=None,
+                 java_class=None,
                  executor_cores=None,
                  executor_memory=None,
+                 driver_memory=None,
                  keytab=None,
                  principal=None,
                  name='airflow-spark',
@@ -78,8 +85,10 @@ class SparkSubmitOperator(BaseOperator):
         self._files = files
         self._py_files = py_files
         self._jars = jars
+        self._java_class = java_class
         self._executor_cores = executor_cores
         self._executor_memory = executor_memory
+        self._driver_memory = driver_memory
         self._keytab = keytab
         self._principal = principal
         self._name = name
@@ -98,8 +107,10 @@ class SparkSubmitOperator(BaseOperator):
             files=self._files,
             py_files=self._py_files,
             jars=self._jars,
+            java_class=self._java_class,
             executor_cores=self._executor_cores,
             executor_memory=self._executor_memory,
+            driver_memory=self._driver_memory,
             keytab=self._keytab,
             principal=self._principal,
             name=self._name,

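For reference, the command the hook assembles can be inspected directly, much
as the tests below do. This is a sketch only; the conn_id and jar path are
illustrative, and _build_command is a private helper:

    from airflow.contrib.hooks.spark_submit_hook import SparkSubmitHook

    # Illustrative values; with a connection whose extra sets "spark-home", the
    # resulting command starts with <spark-home>/bin/spark-submit and includes
    # --class and --driver-memory from the new parameters.
    hook = SparkSubmitHook(
        conn_id='spark_home_set',
        java_class='com.foo.bar.AppMain',
        driver_memory='3g',
    )
    print(' '.join(hook._build_command('/path/to/app.jar')))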
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0ade066f/tests/contrib/hooks/spark_submit_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/spark_submit_hook.py b/tests/contrib/hooks/spark_submit_hook.py
index b18925a..8f514c2 100644
--- a/tests/contrib/hooks/spark_submit_hook.py
+++ b/tests/contrib/hooks/spark_submit_hook.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import os
 import unittest
 
 from airflow import configuration, models
@@ -37,7 +37,9 @@ class TestSparkSubmitHook(unittest.TestCase):
         'principal': 'user/spark@airflow.org',
         'name': 'spark-job',
         'num_executors': 10,
-        'verbose': True
+        'verbose': True,
+        'driver_memory': '3g',
+        'java_class': 'com.foo.bar.AppMain'
     }
 
     def setUp(self):
@@ -45,7 +47,7 @@ class TestSparkSubmitHook(unittest.TestCase):
         db.merge_conn(
             models.Connection(
                 conn_id='spark_yarn_cluster', conn_type='spark',
-                host='yarn://yarn-mater', extra='{"queue": "root.etl", "deploy-mode": "cluster"}')
+                host='yarn://yarn-master', extra='{"queue": "root.etl", "deploy-mode": "cluster"}')
         )
         db.merge_conn(
             models.Connection(
@@ -53,6 +55,19 @@ class TestSparkSubmitHook(unittest.TestCase):
                 host='mesos://host', port=5050)
         )
 
+        db.merge_conn(
+            models.Connection(
+                conn_id='spark_home_set', conn_type='spark',
+                host='yarn://yarn-master',
+                extra='{"spark-home": "/opt/myspark"}')
+        )
+
+        db.merge_conn(
+            models.Connection(
+                conn_id='spark_home_not_set', conn_type='spark',
+                host='yarn://yarn-master')
+        )
+
     def test_build_command(self):
         hook = SparkSubmitHook(**self._config)
 
@@ -72,6 +87,8 @@ class TestSparkSubmitHook(unittest.TestCase):
         assert "--principal {}".format(self._config['principal']) in cmd
         assert "--name {}".format(self._config['name']) in cmd
         assert "--num-executors {}".format(self._config['num_executors']) in cmd
+        assert "--class {}".format(self._config['java_class']) in cmd
+        assert "--driver-memory {}".format(self._config['driver_memory']) in cmd
 
         # Check if all config settings are there
         for k in self._config['conf']:
@@ -92,14 +109,14 @@ class TestSparkSubmitHook(unittest.TestCase):
 
         # Default to the standard yarn connection because conn_id does not exist
         hook = SparkSubmitHook(conn_id='')
-        self.assertEqual(hook._resolve_connection(), ('yarn', None, None))
+        self.assertEqual(hook._resolve_connection(), ('yarn', None, None, None))
         assert "--master yarn" in ' '.join(hook._build_command(self._spark_job_file))
 
         # Default to the standard yarn connection
         hook = SparkSubmitHook(conn_id='spark_default')
         self.assertEqual(
             hook._resolve_connection(),
-            ('yarn', 'root.default', None)
+            ('yarn', 'root.default', None, None)
         )
         cmd = ' '.join(hook._build_command(self._spark_job_file))
         assert "--master yarn" in cmd
@@ -109,7 +126,7 @@ class TestSparkSubmitHook(unittest.TestCase):
         hook = SparkSubmitHook(conn_id='spark_default_mesos')
         self.assertEqual(
             hook._resolve_connection(),
-            ('mesos://host:5050', None, None)
+            ('mesos://host:5050', None, None, None)
         )
 
         cmd = ' '.join(hook._build_command(self._spark_job_file))
@@ -119,7 +136,7 @@ class TestSparkSubmitHook(unittest.TestCase):
         hook = SparkSubmitHook(conn_id='spark_yarn_cluster')
         self.assertEqual(
             hook._resolve_connection(),
-            ('yarn://yarn-master', 'root.etl', 'cluster')
+            ('yarn://yarn-master', 'root.etl', 'cluster', None)
         )
 
         cmd = ' '.join(hook._build_command(self._spark_job_file))
@@ -127,6 +144,26 @@ class TestSparkSubmitHook(unittest.TestCase):
         assert "--queue root.etl" in cmd
         assert "--deploy-mode cluster" in cmd
 
+        # Set the spark home
+        hook = SparkSubmitHook(conn_id='spark_home_set')
+        self.assertEqual(
+            hook._resolve_connection(),
+            ('yarn://yarn-master', None, None, '/opt/myspark')
+        )
+
+        cmd = ' '.join(hook._build_command(self._spark_job_file))
+        assert cmd.startswith('/opt/myspark/bin/spark-submit')
+
+        # Spark home not set
+        hook = SparkSubmitHook(conn_id='spark_home_not_set')
+        self.assertEqual(
+            hook._resolve_connection(),
+            ('yarn://yarn-master', None, None, None)
+        )
+
+        cmd = ' '.join(hook._build_command(self._spark_job_file))
+        assert cmd.startswith('spark-submit')
+
     def test_process_log(self):
         # Must select yarn connection
         hook = SparkSubmitHook(conn_id='spark_yarn_cluster')

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0ade066f/tests/contrib/operators/spark_submit_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/spark_submit_operator.py b/tests/contrib/operators/spark_submit_operator.py
index c080f76..4e2afb2 100644
--- a/tests/contrib/operators/spark_submit_operator.py
+++ b/tests/contrib/operators/spark_submit_operator.py
@@ -37,7 +37,9 @@ class TestSparkSubmitOperator(unittest.TestCase):
         'name': 'spark-job',
         'num_executors': 10,
         'verbose': True,
-        'application': 'test_application.py'
+        'application': 'test_application.py',
+        'driver_memory': '3g',
+        'java_class': 'com.foo.bar.AppMain'
     }
 
     def setUp(self):
@@ -69,6 +71,10 @@ class TestSparkSubmitOperator(unittest.TestCase):
         self.assertEqual(self._config['name'], operator._name)
         self.assertEqual(self._config['num_executors'], operator._num_executors)
         self.assertEqual(self._config['verbose'], operator._verbose)
+        self.assertEqual(self._config['java_class'], operator._java_class)
+        self.assertEqual(self._config['driver_memory'], operator._driver_memory)
+
+
 
 
 if __name__ == '__main__':