Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2018/12/31 22:51:39 UTC

[GitHub] stale[bot] closed pull request #2451: [AIRFLOW-1409] Parse ENV connections differently in SparkSubmitHook

URL: https://github.com/apache/incubator-airflow/pull/2451
 
 
   

This pull request was closed by stale[bot] rather than merged. As it is a foreign
pull request (from a fork), the diff is reproduced below for the sake of provenance,
since GitHub does not show it otherwise:

diff --git a/airflow/contrib/hooks/spark_submit_hook.py b/airflow/contrib/hooks/spark_submit_hook.py
index a667753470..0494b8e0d2 100644
--- a/airflow/contrib/hooks/spark_submit_hook.py
+++ b/airflow/contrib/hooks/spark_submit_hook.py
@@ -14,15 +14,15 @@
 #
 import logging
 import os
-import subprocess
+import random
 import re
+import subprocess
 
 from airflow.hooks.base_hook import BaseHook
 from airflow.exceptions import AirflowException
 
 log = logging.getLogger(__name__)
 
-
 class SparkSubmitHook(BaseHook):
     """
     This hook is a wrapper around the spark-submit binary to kick off a spark-submit job.
@@ -103,6 +103,15 @@ def __init__(self,
         self._connection = self._resolve_connection()
         self._is_yarn = 'yarn' in self._connection['master']
 
+    @staticmethod
+    def _get_master_url(host, port=None, scheme=None):
+        master_url = host
+        if port:
+            master_url = "{}:{}".format(master_url, port)
+        if scheme:
+            master_url = "{}://{}".format(scheme, master_url)
+        return master_url
+
     def _resolve_connection(self):
         # Build from connection master or default to yarn if not available
         conn_data = {'master': 'yarn',
@@ -112,12 +121,16 @@ def _resolve_connection(self):
                      'spark_binary': 'spark-submit'}
 
         try:
-            # Master can be local, yarn, spark://HOST:PORT or mesos://HOST:PORT
-            conn = self.get_connection(self._conn_id)
-            if conn.port:
-                conn_data['master'] = "{}:{}".format(conn.host, conn.port)
+            # If we use an ENV to define the connection, master can be //local,
+            # //yarn, yarn://HOST:PORT, spark://HOST:PORT or mesos://HOST:PORT
+            conn = self._get_connection_from_env(self._conn_id)
+            if conn:
+                conn_data['master'] = self._get_master_url(conn.host, conn.port, conn.conn_type)
+            # If the connection comes from the database, master can be local,
+            # yarn, yarn://HOST:PORT, spark://HOST:PORT or mesos://HOST:PORT
             else:
-                conn_data['master'] = conn.host
+                conn = random.choice(self._get_connections_from_db(self._conn_id))
+                conn_data['master'] = self._get_master_url(conn.host, conn.port)
 
             # Determine optional yarn queue from the extra field
             extra = conn.extra_dejson
diff --git a/tests/contrib/hooks/test_spark_submit_hook.py b/tests/contrib/hooks/test_spark_submit_hook.py
index 826576f0b8..57670473cf 100644
--- a/tests/contrib/hooks/test_spark_submit_hook.py
+++ b/tests/contrib/hooks/test_spark_submit_hook.py
@@ -12,7 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
+import random
 import six
+import string
 import sys
 import unittest
 
@@ -61,6 +64,10 @@ def cmd_args_to_dict(list_cmd):
                 return_dict[arg] = list_cmd[pos+1]
         return return_dict
 
+    @staticmethod
+    def gen_conn_name(length):
+        return ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(length))
+
     def setUp(self):
 
         configuration.load_test_config()
@@ -288,6 +295,106 @@ def test_resolve_connection_spark_binary_and_home_set_connection(self):
         self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(cmd[0], '/path/to/spark_home/bin/custom-spark-submit')
 
+    def test_resolve_connection_yarn_env_connection(self):
+        # Given
+        conn_name = self.gen_conn_name(10)
+        os.environ["AIRFLOW_CONN_SPARK_{}".format(conn_name.upper())] = "//yarn"
+        hook = SparkSubmitHook(conn_id='spark_{}'.format(conn_name))
+
+        # When
+        connection = hook._resolve_connection()
+        cmd = hook._build_command(self._spark_job_file)
+
+        # Then
+        dict_cmd = self.cmd_args_to_dict(cmd)
+        expected_spark_connection = {"master": "yarn",
+                                     "spark_binary": "spark-submit",
+                                     "deploy_mode": None,
+                                     "queue": None,
+                                     "spark_home": None}
+        self.assertEqual(connection, expected_spark_connection)
+        self.assertEqual(dict_cmd["--master"], "yarn")
+
+    def test_resolve_connection_yarn_cluster_env_connection(self):
+        # Given
+        conn_name = self.gen_conn_name(10)
+        os.environ["AIRFLOW_CONN_SPARK_{}".format(conn_name.upper())] = "yarn://yarn-master"
+        hook = SparkSubmitHook(conn_id='spark_{}'.format(conn_name))
+
+        # When
+        connection = hook._resolve_connection()
+        cmd = hook._build_command(self._spark_job_file)
+
+        # Then
+        dict_cmd = self.cmd_args_to_dict(cmd)
+        expected_spark_connection = {"master": "yarn://yarn-master",
+                                     "spark_binary": "spark-submit",
+                                     "deploy_mode": None,
+                                     "queue": None,
+                                     "spark_home": None}
+        self.assertEqual(connection, expected_spark_connection)
+        self.assertEqual(dict_cmd["--master"], "yarn://yarn-master")
+
+    def test_resolve_connection_spark_local_env_connection(self):
+        # Given
+        conn_name = self.gen_conn_name(10)
+        os.environ["AIRFLOW_CONN_SPARK_{}".format(conn_name.upper())] = "//local"
+        hook = SparkSubmitHook(conn_id='spark_{}'.format(conn_name))
+
+        # When
+        connection = hook._resolve_connection()
+        cmd = hook._build_command(self._spark_job_file)
+
+        # Then
+        dict_cmd = self.cmd_args_to_dict(cmd)
+        expected_spark_connection = {"master": "local",
+                                     "spark_binary": "spark-submit",
+                                     "deploy_mode": None,
+                                     "queue": None,
+                                     "spark_home": None}
+        self.assertEqual(connection, expected_spark_connection)
+        self.assertEqual(dict_cmd["--master"], "local")
+
+    def test_resolve_connection_spark_cluster_env_connection(self):
+        # Given
+        conn_name = self.gen_conn_name(10)
+        os.environ["AIRFLOW_CONN_SPARK_{}".format(conn_name.upper())] = "spark://spark-master:7077"
+        hook = SparkSubmitHook(conn_id='spark_{}'.format(conn_name))
+
+        # When
+        connection = hook._resolve_connection()
+        cmd = hook._build_command(self._spark_job_file)
+
+        # Then
+        dict_cmd = self.cmd_args_to_dict(cmd)
+        expected_spark_connection = {"master": "spark://spark-master:7077",
+                                     "spark_binary": "spark-submit",
+                                     "deploy_mode": None,
+                                     "queue": None,
+                                     "spark_home": None}
+        self.assertEqual(connection, expected_spark_connection)
+        self.assertEqual(dict_cmd["--master"], "spark://spark-master:7077")
+
+    def test_resolve_connection_mesos_cluster_env_connection(self):
+        # Given
+        conn_name = self.gen_conn_name(10)
+        os.environ["AIRFLOW_CONN_SPARK_{}".format(conn_name.upper())] = "mesos://mesos-master:5050"
+        hook = SparkSubmitHook(conn_id='spark_{}'.format(conn_name))
+
+        # When
+        connection = hook._resolve_connection()
+        cmd = hook._build_command(self._spark_job_file)
+
+        # Then
+        dict_cmd = self.cmd_args_to_dict(cmd)
+        expected_spark_connection = {"master": "mesos://mesos-master:5050",
+                                     "spark_binary": "spark-submit",
+                                     "deploy_mode": None,
+                                     "queue": None,
+                                     "spark_home": None}
+        self.assertEqual(connection, expected_spark_connection)
+        self.assertEqual(dict_cmd["--master"], "mesos://mesos-master:5050")
+
     def test_process_log(self):
         # Given
         hook = SparkSubmitHook(conn_id='spark_yarn_cluster')
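
For readers of this archive, here is a small standalone sketch of the master-URL
composition the patch introduces. It is not part of the patch itself: the helper's
behaviour is copied from the diff above, the example values mirror the new tests,
and the connection id in the comments is hypothetical.

    def _get_master_url(host, port=None, scheme=None):
        # Compose "host", "host:port" or "scheme://host:port" depending on
        # which parts the connection provides.
        master_url = host
        if port:
            master_url = "{}:{}".format(master_url, port)
        if scheme:
            master_url = "{}://{}".format(scheme, master_url)
        return master_url

    # ENV-defined connection, e.g. AIRFLOW_CONN_SPARK_EXAMPLE=spark://spark-master:7077
    # (hypothetical id): the conn_type is passed as the scheme, so it is preserved,
    # as in test_resolve_connection_spark_cluster_env_connection above.
    assert _get_master_url("spark-master", 7077, "spark") == "spark://spark-master:7077"
    # ENV value "//yarn": host only, no scheme and no port.
    assert _get_master_url("yarn") == "yarn"
    # Database-defined connection: the scheme is not passed, so only host[:port] is used.
    assert _get_master_url("spark-master", 7077) == "spark-master:7077"

Note that in the database path the patch picks one connection with random.choice when
several rows share the same conn_id, which is why the hook now imports random.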


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services