Posted to commits@spark.apache.org by ir...@apache.org on 2018/09/13 14:20:42 UTC

[4/4] spark git commit: [SPARK-25253][PYSPARK] Refactor local connection & auth code

[SPARK-25253][PYSPARK] Refactor local connection & auth code

This eliminates some duplication in the code that connects to a server on localhost to talk directly to the JVM. It also gives consistent IPv6 and error handling (the connect loop is sketched after the list below). There are two other incidental changes that shouldn't matter:
1) Python barrier tasks perform authentication immediately (rather than waiting for the BARRIER_FUNCTION indicator).
2) For `rdd._load_from_socket`, the timeout is only increased after authentication.
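
For illustration, here is a minimal standalone sketch of the consolidated connect loop. It mirrors, but is not, the committed local_connect_and_auth shown in the java_gateway.py hunk below; the auth step is omitted, and the helper name _connect_localhost is just illustrative:

    import socket

    def _connect_localhost(port, timeout=15):
        errors = []
        # getaddrinfo may yield both IPv4 and IPv6 candidates for localhost;
        # try each in turn and remember why each attempt failed.
        for af, socktype, proto, _, sa in socket.getaddrinfo(
                "localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
            sock = socket.socket(af, socktype, proto)
            try:
                sock.settimeout(timeout)
                sock.connect(sa)
                return sock
            except socket.error as e:
                errors.append("tried to connect to %s: %s" % (sa, e))
                sock.close()
        # Every candidate failed; surface all of the errors, not just the last.
        raise Exception("could not open socket: %s" % errors)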

Closes #22247 from squito/py_connection_refactor.

Authored-by: Imran Rashid <ir...@cloudera.com>
Signed-off-by: hyukjinkwon <gu...@apache.org>
(cherry picked from commit 38391c9aa8a88fcebb337934f30298a32d91596b)


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a2a54a5f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a2a54a5f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a2a54a5f

Branch: refs/heads/branch-2.3
Commit: a2a54a5f49364a1825932c9f04eb0ff82dd7d465
Parents: 9ac9f36
Author: Imran Rashid <ir...@cloudera.com>
Authored: Wed Aug 29 09:47:38 2018 +0800
Committer: Imran Rashid <ir...@cloudera.com>
Committed: Thu Sep 13 09:19:56 2018 -0500

----------------------------------------------------------------------
 python/pyspark/java_gateway.py | 32 +++++++++++++++++++++++++++++++-
 python/pyspark/rdd.py          | 24 ++----------------------
 python/pyspark/worker.py       |  7 ++-----
 3 files changed, 35 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a2a54a5f/python/pyspark/java_gateway.py
----------------------------------------------------------------------
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 0afbe9d..eed866d 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -134,7 +134,7 @@ def launch_gateway(conf=None):
     return gateway
 
 
-def do_server_auth(conn, auth_secret):
+def _do_server_auth(conn, auth_secret):
     """
     Performs the authentication protocol defined by the SocketAuthHelper class on the given
     file-like object 'conn'.
@@ -145,3 +145,33 @@ def do_server_auth(conn, auth_secret):
     if reply != "ok":
         conn.close()
         raise Exception("Unexpected reply from iterator server.")
+
+
+def local_connect_and_auth(port, auth_secret):
+    """
+    Connect to local host, authenticate with it, and return a (sockfile,sock) for that connection.
+    Handles IPV4 & IPV6, does some error handling.
+    :param port
+    :param auth_secret
+    :return: a tuple with (sockfile, sock)
+    """
+    sock = None
+    errors = []
+    # Support for both IPv4 and IPv6.
+    # On most of IPv6-ready systems, IPv6 will take precedence.
+    for res in socket.getaddrinfo("127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
+        af, socktype, proto, _, sa = res
+        try:
+            sock = socket.socket(af, socktype, proto)
+            sock.settimeout(15)
+            sock.connect(sa)
+            sockfile = sock.makefile("rwb", 65536)
+            _do_server_auth(sockfile, auth_secret)
+            return (sockfile, sock)
+        except socket.error as e:
+            emsg = _exception_message(e)
+            errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg))
+            sock.close()
+            sock = None
+    else:
+        raise Exception("could not open socket: %s" % errors)

http://git-wip-us.apache.org/repos/asf/spark/blob/a2a54a5f/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index eadb5ab..3f6a8e6 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -39,7 +39,7 @@ if sys.version > '3':
 else:
     from itertools import imap as map, ifilter as filter
 
-from pyspark.java_gateway import do_server_auth
+from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
     PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \
@@ -139,30 +139,10 @@ def _parse_memory(s):
 
 
 def _load_from_socket(sock_info, serializer):
-    port, auth_secret = sock_info
-    sock = None
-    # Support for both IPv4 and IPv6.
-    # On most of IPv6-ready systems, IPv6 will take precedence.
-    for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
-        af, socktype, proto, canonname, sa = res
-        sock = socket.socket(af, socktype, proto)
-        try:
-            sock.settimeout(15)
-            sock.connect(sa)
-        except socket.error:
-            sock.close()
-            sock = None
-            continue
-        break
-    if not sock:
-        raise Exception("could not open socket")
+    (sockfile, sock) = local_connect_and_auth(*sock_info)
     # The RDD materialization time is unpredicable, if we set a timeout for socket reading
     # operation, it will very possibly fail. See SPARK-18281.
     sock.settimeout(None)
-
-    sockfile = sock.makefile("rwb", 65536)
-    do_server_auth(sockfile, auth_secret)
-
     # The socket will be automatically closed when garbage-collected.
     return serializer.load_stream(sockfile)
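
The ordering above is incidental change 2) from the commit message: the 15-second timeout set inside local_connect_and_auth stays in effect through connect and auth, and is lifted only afterwards. Schematically (port, auth_secret, and the handle consumer are placeholders, not committed code):

    sockfile, sock = local_connect_and_auth(port, auth_secret)  # 15s timeout applies here
    sock.settimeout(None)  # then block indefinitely; materialization is unbounded (SPARK-18281)
    for record in serializer.load_stream(sockfile):
        handle(record)     # hypothetical consumer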
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a2a54a5f/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 788b323..812f4b2 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -27,7 +27,7 @@ import traceback
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
-from pyspark.java_gateway import do_server_auth
+from pyspark.java_gateway import local_connect_and_auth
 from pyspark.taskcontext import TaskContext
 from pyspark.files import SparkFiles
 from pyspark.rdd import PythonEvalType
@@ -269,8 +269,5 @@ if __name__ == '__main__':
     # Read information about how to connect back to the JVM from the environment.
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    sock.connect(("127.0.0.1", java_port))
-    sock_file = sock.makefile("rwb", 65536)
-    do_server_auth(sock_file, auth_secret)
+    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
     main(sock_file, sock_file)
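
For context, those two environment variables are the worker's only link back to the JVM; on the Scala side, PythonWorkerFactory exports them before starting the worker process. A hypothetical launcher doing the same by hand would look roughly like this (the port and secret values are placeholders, and without a JVM actually listening the worker's connect will simply fail):

    import os
    import subprocess
    import sys

    env = dict(os.environ)
    env["PYTHON_WORKER_FACTORY_PORT"] = "12345"     # placeholder port
    env["PYTHON_WORKER_FACTORY_SECRET"] = "secret"  # placeholder secret
    subprocess.Popen([sys.executable, "-m", "pyspark.worker"], env=env)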

