You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ir...@apache.org on 2018/09/25 16:57:49 UTC
[2/3] spark git commit: [PYSPARK] Updates to pyspark broadcast
[PYSPARK] Updates to pyspark broadcast
(cherry picked from commit 09dd34cb1706f2477a89174d6a1a0f17ed5b0a65)
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/dd0e7cf5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/dd0e7cf5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/dd0e7cf5
Branch: refs/heads/branch-2.2
Commit: dd0e7cf5287148618404593ca095dd900b6e993f
Parents: fc1c4e7
Author: Imran Rashid <ir...@cloudera.com>
Authored: Mon Aug 13 21:35:34 2018 -0500
Committer: Imran Rashid <ir...@cloudera.com>
Committed: Tue Sep 25 11:46:03 2018 -0500
----------------------------------------------------------------------
.../org/apache/spark/api/python/PythonRDD.scala | 349 ++++++++++++++++---
.../spark/api/python/PythonRDDSuite.scala | 23 +-
dev/sparktestsupport/modules.py | 2 +
python/pyspark/broadcast.py | 58 ++-
python/pyspark/context.py | 63 +++-
python/pyspark/serializers.py | 58 +++
python/pyspark/test_broadcast.py | 126 +++++++
python/pyspark/test_serializers.py | 90 +++++
python/pyspark/worker.py | 24 +-
9 files changed, 705 insertions(+), 88 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/dd0e7cf5/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 7b5a179..2f4e3bc 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -24,8 +24,10 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
import scala.collection.JavaConverters._
import scala.collection.mutable
+import scala.concurrent.Promise
+import scala.concurrent.duration.Duration
import scala.language.existentials
-import scala.util.control.NonFatal
+import scala.util.Try
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
@@ -37,6 +39,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.input.PortableDataStream
import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.util._
@@ -293,19 +296,51 @@ private[spark] class PythonRunner(
val newBids = broadcastVars.map(_.id).toSet
// number of different broadcasts
val toRemove = oldBids.diff(newBids)
- val cnt = toRemove.size + newBids.diff(oldBids).size
+ val addedBids = newBids.diff(oldBids)
+ val cnt = toRemove.size + addedBids.size
+ val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty
+ dataOut.writeBoolean(needsDecryptionServer)
dataOut.writeInt(cnt)
- for (bid <- toRemove) {
- // remove the broadcast from worker
- dataOut.writeLong(- bid - 1) // bid >= 0
- oldBids.remove(bid)
+ def sendBidsToRemove(): Unit = {
+ for (bid <- toRemove) {
+ // remove the broadcast from worker
+ dataOut.writeLong(-bid - 1) // bid >= 0
+ oldBids.remove(bid)
+ }
}
- for (broadcast <- broadcastVars) {
- if (!oldBids.contains(broadcast.id)) {
+ if (needsDecryptionServer) {
+ // if there is encryption, we setup a server which reads the encrypted files, and sends
+ // the decrypted data to python
+ val idsAndFiles = broadcastVars.flatMap { broadcast =>
+ if (oldBids.contains(broadcast.id)) {
+ None
+ } else {
+ Some((broadcast.id, broadcast.value.path))
+ }
+ }
+ val server = new EncryptedPythonBroadcastServer(env, idsAndFiles)
+ dataOut.writeInt(server.port)
+ logTrace(s"broadcast decryption server setup on ${server.port}")
+ PythonRDD.writeUTF(server.secret, dataOut)
+ sendBidsToRemove()
+ idsAndFiles.foreach { case (id, _) =>
// send new broadcast
- dataOut.writeLong(broadcast.id)
- PythonRDD.writeUTF(broadcast.value.path, dataOut)
- oldBids.add(broadcast.id)
+ dataOut.writeLong(id)
+ oldBids.add(id)
+ }
+ dataOut.flush()
+ logTrace("waiting for python to read decrypted broadcast data from server")
+ server.waitTillBroadcastDataSent()
+ logTrace("done sending decrypted data to python")
+ } else {
+ sendBidsToRemove()
+ for (broadcast <- broadcastVars) {
+ if (!oldBids.contains(broadcast.id)) {
+ // send new broadcast
+ dataOut.writeLong(broadcast.id)
+ PythonRDD.writeUTF(broadcast.value.path, dataOut)
+ oldBids.add(broadcast.id)
+ }
}
}
dataOut.flush()
@@ -482,27 +517,34 @@ private[spark] object PythonRDD extends Logging {
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
- val file = new DataInputStream(new FileInputStream(filename))
+ readRDDFromInputStream(sc.sc, new FileInputStream(filename), parallelism)
+ }
+
+ def readRDDFromInputStream(
+ sc: SparkContext,
+ in: InputStream,
+ parallelism: Int): JavaRDD[Array[Byte]] = {
+ val din = new DataInputStream(in)
try {
val objs = new mutable.ArrayBuffer[Array[Byte]]
try {
while (true) {
- val length = file.readInt()
+ val length = din.readInt()
val obj = new Array[Byte](length)
- file.readFully(obj)
+ din.readFully(obj)
objs += obj
}
} catch {
case eof: EOFException => // No-op
}
- JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
+ JavaRDD.fromRDD(sc.parallelize(objs, parallelism))
} finally {
- file.close()
+ din.close()
}
}
- def readBroadcastFromFile(sc: JavaSparkContext, path: String): Broadcast[PythonBroadcast] = {
- sc.broadcast(new PythonBroadcast(path))
+ def setupBroadcast(path: String): PythonBroadcast = {
+ new PythonBroadcast(path)
}
def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
@@ -712,34 +754,15 @@ private[spark] object PythonRDD extends Logging {
* data collected from this job, and the secret for authentication.
*/
def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
- val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
- // Close the socket if no connection in 15 seconds
- serverSocket.setSoTimeout(15000)
-
- new Thread(threadName) {
- setDaemon(true)
- override def run() {
- try {
- val sock = serverSocket.accept()
- authHelper.authClient(sock)
-
- val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
- Utils.tryWithSafeFinally {
- writeIteratorToStream(items, out)
- } {
- out.close()
- sock.close()
- }
- } catch {
- case NonFatal(e) =>
- logError(s"Error while sending iterator", e)
- } finally {
- serverSocket.close()
- }
+ val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { s =>
+ val out = new DataOutputStream(new BufferedOutputStream(s.getOutputStream()))
+ Utils.tryWithSafeFinally {
+ writeIteratorToStream(items, out)
+ } {
+ out.close()
}
- }.start()
-
- Array(serverSocket.getLocalPort, authHelper.secret)
+ }
+ Array(port, secret)
}
private def getMergedConf(confAsMap: java.util.HashMap[String, String],
@@ -957,13 +980,11 @@ private[spark] class PythonAccumulatorV2(
}
}
-/**
- * A Wrapper for Python Broadcast, which is written into disk by Python. It also will
- * write the data into disk after deserialization, then Python can read it from disks.
- */
// scalastyle:off no.finalize
private[spark] class PythonBroadcast(@transient var path: String) extends Serializable
- with Logging {
+ with Logging {
+
+ private var encryptionServer: PythonServer[Unit] = null
/**
* Read data from disks, then copy it to `out`
@@ -1005,5 +1026,233 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
}
}
}
+
+ def setupEncryptionServer(): Array[Any] = {
+ encryptionServer = new PythonServer[Unit]("broadcast-encrypt-server") {
+ override def handleConnection(sock: Socket): Unit = {
+ val env = SparkEnv.get
+ val in = sock.getInputStream()
+ val dir = new File(Utils.getLocalDir(env.conf))
+ val file = File.createTempFile("broadcast", "", dir)
+ path = file.getAbsolutePath
+ val out = env.serializerManager.wrapForEncryption(new FileOutputStream(path))
+ DechunkedInputStream.dechunkAndCopyToOutput(in, out)
+ }
+ }
+ Array(encryptionServer.port, encryptionServer.secret)
+ }
+
+ def waitTillDataReceived(): Unit = encryptionServer.getResult()
}
// scalastyle:on no.finalize
+
+/**
+ * The inverse of pyspark's ChunkedStream for sending broadcast data.
+ * Tested from python tests.
+ */
+private[spark] class DechunkedInputStream(wrapped: InputStream) extends InputStream with Logging {
+ private val din = new DataInputStream(wrapped)
+ private var remainingInChunk = din.readInt()
+
+ override def read(): Int = {
+ val into = new Array[Byte](1)
+ val n = read(into, 0, 1)
+ if (n == -1) {
+ -1
+ } else {
+ // if you just cast a byte to an int, then anything > 127 is negative, which is interpreted
+ // as an EOF
+ into(0) & 0xFF
+ }
+ }
+
+ override def read(dest: Array[Byte], off: Int, len: Int): Int = {
+ if (remainingInChunk == -1) {
+ return -1
+ }
+ var destSpace = len
+ var destPos = off
+ while (destSpace > 0 && remainingInChunk != -1) {
+ val toCopy = math.min(remainingInChunk, destSpace)
+ val read = din.read(dest, destPos, toCopy)
+ destPos += read
+ destSpace -= read
+ remainingInChunk -= read
+ if (remainingInChunk == 0) {
+ remainingInChunk = din.readInt()
+ }
+ }
+ assert(destSpace == 0 || remainingInChunk == -1)
+ return destPos - off
+ }
+
+ override def close(): Unit = wrapped.close()
+}
+
+/**
+ * The inverse of pyspark's ChunkedStream for sending data of unknown size.
+ *
+ * We might be serializing a really large object from python -- we don't want
+ * python to buffer the whole thing in memory, nor can it write to a file,
+ * so we don't know the length in advance. So python writes it in chunks, each chunk
+ * preceded by a length, till we get a "length" of -1 which serves as EOF.
+ *
+ * Tested from python tests.
+ */
+private[spark] object DechunkedInputStream {
+
+ /**
+ * Dechunks the input, copies to output, and closes both input and the output safely.
+ */
+ def dechunkAndCopyToOutput(chunked: InputStream, out: OutputStream): Unit = {
+ val dechunked = new DechunkedInputStream(chunked)
+ Utils.tryWithSafeFinally {
+ Utils.copyStream(dechunked, out)
+ } {
+ JavaUtils.closeQuietly(out)
+ JavaUtils.closeQuietly(dechunked)
+ }
+ }
+}
+
+/**
+ * Creates a server in the jvm to communicate with python for handling one batch of data, with
+ * authentication and error handling.
+ */
+private[spark] abstract class PythonServer[T](
+ authHelper: SocketAuthHelper,
+ threadName: String) {
+
+ def this(env: SparkEnv, threadName: String) = this(new SocketAuthHelper(env.conf), threadName)
+ def this(threadName: String) = this(SparkEnv.get, threadName)
+
+ val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { sock =>
+ promise.complete(Try(handleConnection(sock)))
+ }
+
+ /**
+ * Handle a connection which has already been authenticated. Any error from this function
+ * will clean up this connection and the entire server, and get propagated to [[getResult]].
+ */
+ def handleConnection(sock: Socket): T
+
+ val promise = Promise[T]()
+
+ /**
+ * Blocks indefinitely for [[handleConnection]] to finish, and returns that result. If
+ * handleConnection throws, this throws an exception with the original exception as its cause.
+ * NOTE: never returns if accept() times out or authClient throws (promise is never completed).
+ */
+ def getResult(): T = {
+ getResult(Duration.Inf)
+ }
+
+ def getResult(wait: Duration): T = {
+ ThreadUtils.awaitResult(promise.future, wait)
+ }
+
+}
+
+private[spark] object PythonServer {
+
+ /**
+ * Create a socket server and run user function on the socket in a background thread.
+ *
+ * The socket server can only accept one connection, or close if no connection
+ * in 15 seconds.
+ *
+ * The thread will terminate after the supplied user function, or if there are any exceptions.
+ *
+ * If you need to get a result of the supplied function, create a subclass of [[PythonServer]]
+ *
+ * @return The port number of a local socket and the secret for authentication.
+ */
+ def setupOneConnectionServer(
+ authHelper: SocketAuthHelper,
+ threadName: String)
+ (func: Socket => Unit): (Int, String) = {
+ val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
+ // Close the socket if no connection in 15 seconds
+ serverSocket.setSoTimeout(15000)
+
+ new Thread(threadName) {
+ setDaemon(true)
+ override def run(): Unit = {
+ var sock: Socket = null
+ try {
+ sock = serverSocket.accept()
+ authHelper.authClient(sock)
+ func(sock)
+ } finally {
+ JavaUtils.closeQuietly(serverSocket)
+ JavaUtils.closeQuietly(sock)
+ }
+ }
+ }.start()
+ (serverSocket.getLocalPort, authHelper.secret)
+ }
+}
+
+/**
+ * Sends decrypted broadcast data to python worker. See [[PythonRunner]] for entire protocol.
+ */
+private[spark] class EncryptedPythonBroadcastServer(
+ val env: SparkEnv,
+ val idsAndFiles: Seq[(Long, String)])
+ extends PythonServer[Unit]("broadcast-decrypt-server") with Logging {
+
+ override def handleConnection(socket: Socket): Unit = {
+ val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream()))
+ var socketIn: InputStream = null
+ // send the broadcast id, then the decrypted data. We don't need to send the length, as
+ // the python pickle module just needs a stream.
+ Utils.tryWithSafeFinally {
+ (idsAndFiles).foreach { case (id, path) =>
+ out.writeLong(id)
+ val in = env.serializerManager.wrapForEncryption(new FileInputStream(path))
+ Utils.tryWithSafeFinally {
+ Utils.copyStream(in, out, false)
+ } {
+ in.close()
+ }
+ }
+ logTrace("waiting for python to accept broadcast data over socket")
+ out.flush()
+ socketIn = socket.getInputStream()
+ socketIn.read()
+ logTrace("done serving broadcast data")
+ } {
+ JavaUtils.closeQuietly(socketIn)
+ JavaUtils.closeQuietly(out)
+ }
+ }
+
+ def waitTillBroadcastDataSent(): Unit = {
+ getResult()
+ }
+}
+
+/**
+ * Helper for making RDD[Array[Byte]] from some python data, by reading the data from python
+ * over a socket. This is used in preference to writing data to a file when encryption is enabled.
+ */
+private[spark] abstract class PythonRDDServer
+ extends PythonServer[JavaRDD[Array[Byte]]]("pyspark-parallelize-server") {
+
+ def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
+ val in = sock.getInputStream()
+ val dechunkedInput: InputStream = new DechunkedInputStream(in)
+ streamToRDD(dechunkedInput)
+ }
+
+ protected def streamToRDD(input: InputStream): RDD[Array[Byte]]
+
+}
+
+private[spark] class PythonParallelizeServer(sc: SparkContext, parallelism: Int)
+ extends PythonRDDServer {
+
+ override protected def streamToRDD(input: InputStream): RDD[Array[Byte]] = {
+ PythonRDD.readRDDFromInputStream(sc, input, parallelism)
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/dd0e7cf5/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
index 05b4e67..6f9b583 100644
--- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -18,9 +18,13 @@
package org.apache.spark.api.python
import java.io.{ByteArrayOutputStream, DataOutputStream}
+import java.net.{InetAddress, Socket}
import java.nio.charset.StandardCharsets
-import org.apache.spark.SparkFunSuite
+import scala.concurrent.duration.Duration
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.security.SocketAuthHelper
class PythonRDDSuite extends SparkFunSuite {
@@ -44,4 +48,21 @@ class PythonRDDSuite extends SparkFunSuite {
("a".getBytes(StandardCharsets.UTF_8), null),
(null, "b".getBytes(StandardCharsets.UTF_8))), buffer)
}
+
+ test("python server error handling") {
+ val authHelper = new SocketAuthHelper(new SparkConf())
+ val errorServer = new ExceptionPythonServer(authHelper)
+ val client = new Socket(InetAddress.getLoopbackAddress(), errorServer.port)
+ authHelper.authToServer(client)
+ val ex = intercept[Exception] { errorServer.getResult(Duration(1, "second")) }
+ assert(ex.getCause().getMessage().contains("exception within handleConnection"))
+ }
+
+ class ExceptionPythonServer(authHelper: SocketAuthHelper)
+ extends PythonServer[Unit](authHelper, "error-server") {
+
+ override def handleConnection(sock: Socket): Unit = {
+ throw new Exception("exception within handleConnection")
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/dd0e7cf5/dev/sparktestsupport/modules.py
----------------------------------------------------------------------
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 2971e0d..b7683f9 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -353,6 +353,8 @@ pyspark_core = Module(
"pyspark.profiler",
"pyspark.shuffle",
"pyspark.tests",
+ "pyspark.test_broadcast",
+ "pyspark.test_serializers",
"pyspark.util",
]
)
http://git-wip-us.apache.org/repos/asf/spark/blob/dd0e7cf5/python/pyspark/broadcast.py
----------------------------------------------------------------------
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index 02fc515..3f1298e 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -15,13 +15,16 @@
# limitations under the License.
#
+import gc
import os
+import socket
import sys
-import gc
from tempfile import NamedTemporaryFile
import threading
from pyspark.cloudpickle import print_exec
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.serializers import ChunkedStream
from pyspark.util import _exception_message
if sys.version < '3':
@@ -64,19 +67,43 @@ class Broadcast(object):
>>> large_broadcast = sc.broadcast(range(10000))
"""
- def __init__(self, sc=None, value=None, pickle_registry=None, path=None):
+ def __init__(self, sc=None, value=None, pickle_registry=None, path=None,
+ sock_file=None):
"""
Should not be called directly by users -- use L{SparkContext.broadcast()}
instead.
"""
if sc is not None:
+ # we're on the driver. We want the pickled data to end up in a file (maybe encrypted)
f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
- self._path = self.dump(value, f)
- self._jbroadcast = sc._jvm.PythonRDD.readBroadcastFromFile(sc._jsc, self._path)
+ self._path = f.name
+ python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path)
+ if sc._encryption_enabled:
+ # with encryption, we ask the jvm to do the encryption for us, we send it data
+ # over a socket
+ port, auth_secret = python_broadcast.setupEncryptionServer()
+ (encryption_sock_file, _) = local_connect_and_auth(port, auth_secret)
+ broadcast_out = ChunkedStream(encryption_sock_file, 8192)
+ else:
+ # no encryption, we can just write pickled data directly to the file from python
+ broadcast_out = f
+ self.dump(value, broadcast_out)
+ if sc._encryption_enabled:
+ python_broadcast.waitTillDataReceived()
+ self._jbroadcast = sc._jsc.broadcast(python_broadcast)
self._pickle_registry = pickle_registry
else:
+ # we're on an executor
self._jbroadcast = None
- self._path = path
+ if sock_file is not None:
+ # the jvm is doing decryption for us. Read the value
+ # immediately from the sock_file
+ self._value = self.load(sock_file)
+ else:
+ # the jvm just dumps the pickled data in path -- we'll unpickle lazily when
+ # the value is requested
+ assert(path is not None)
+ self._path = path
def dump(self, value, f):
try:
@@ -89,24 +116,25 @@ class Broadcast(object):
print_exec(sys.stderr)
raise pickle.PicklingError(msg)
f.close()
- return f.name
- def load(self, path):
+ def load_from_path(self, path):
with open(path, 'rb', 1 << 20) as f:
- # pickle.load() may create lots of objects, disable GC
- # temporary for better performance
- gc.disable()
- try:
- return pickle.load(f)
- finally:
- gc.enable()
+ return self.load(f)
+
+ def load(self, file):
+ # "file" could also be a socket
+ gc.disable()
+ try:
+ return pickle.load(file)
+ finally:
+ gc.enable()
@property
def value(self):
""" Return the broadcasted value
"""
if not hasattr(self, "_value") and self._path is not None:
- self._value = self.load(self._path)
+ self._value = self.load_from_path(self._path)
return self._value
def unpersist(self, blocking=False):
http://git-wip-us.apache.org/repos/asf/spark/blob/dd0e7cf5/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 02b1b2d..68e4c17 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -33,9 +33,9 @@ from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast, BroadcastPickleRegistry
from pyspark.conf import SparkConf
from pyspark.files import SparkFiles
-from pyspark.java_gateway import launch_gateway
+from pyspark.java_gateway import launch_gateway, local_connect_and_auth
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
- PairDeserializer, AutoBatchedSerializer, NoOpSerializer
+ PairDeserializer, AutoBatchedSerializer, NoOpSerializer, ChunkedStream
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
from pyspark.traceback_utils import CallSite, first_spark_call
@@ -189,6 +189,13 @@ class SparkContext(object):
self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token)
self._jsc.sc().register(self._javaAccumulator)
+ # If encryption is enabled, we need to setup a server in the jvm to read broadcast
+ # data via a socket.
+ # scala's mangled names w/ $ in them require special treatment.
+ encryption_conf = self._jvm.org.apache.spark.internal.config.__getattr__("package$")\
+ .__getattr__("MODULE$").IO_ENCRYPTION_ENABLED()
+ self._encryption_enabled = self._jsc.sc().conf().get(encryption_conf)
+
self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
self.pythonVer = "%d.%d" % sys.version_info[:2]
@@ -479,25 +486,43 @@ class SparkContext(object):
return xrange(getStart(split), getStart(split + 1), step)
return self.parallelize([], numSlices).mapPartitionsWithIndex(f)
- # Calling the Java parallelize() method with an ArrayList is too slow,
- # because it sends O(n) Py4J commands. As an alternative, serialized
- # objects are written to a file and loaded through textFile().
- tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
- try:
- # Make sure we distribute data evenly if it's smaller than self.batchSize
- if "__len__" not in dir(c):
- c = list(c) # Make it a list so we can compute its length
- batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
- serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
- serializer.dump_stream(c, tempFile)
- tempFile.close()
- readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
- jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
- finally:
- # readRDDFromFile eagerily reads the file so we can delete right after.
- os.unlink(tempFile.name)
+
+ # Make sure we distribute data evenly if it's smaller than self.batchSize
+ if "__len__" not in dir(c):
+ c = list(c) # Make it a list so we can compute its length
+ batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
+ serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
+ jrdd = self._serialize_to_jvm(c, numSlices, serializer)
return RDD(jrdd, self, serializer)
+ def _serialize_to_jvm(self, data, parallelism, serializer):
+ """
+ Using py4j to send a large dataset to the jvm is really slow, so we use either a file
+ or a socket if we have encryption enabled.
+ """
+ if self._encryption_enabled:
+ # with encryption, we open a server in java and send the data directly
+ server = self._jvm.PythonParallelizeServer(self._jsc.sc(), parallelism)
+ (sock_file, _) = local_connect_and_auth(server.port(), server.secret())
+ chunked_out = ChunkedStream(sock_file, 8192)
+ serializer.dump_stream(data, chunked_out)
+ chunked_out.close()
+ # this call will block until the server has read all the data and processed it (or
+ # throws an exception)
+ return server.getResult()
+ else:
+ # without encryption, we serialize to a file, and we read the file in java and
+ # parallelize from there.
+ tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
+ try:
+ serializer.dump_stream(data, tempFile)
+ tempFile.close()
+ readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
+ return readRDDFromFile(self._jsc, tempFile.name, parallelism)
+ finally:
+ # we eagerly read the file so we can delete right after.
+ os.unlink(tempFile.name)
+
def pickleFile(self, name, minPartitions=None):
"""
Load an RDD previously saved using L{RDD.saveAsPickleFile} method.
http://git-wip-us.apache.org/repos/asf/spark/blob/dd0e7cf5/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 9bd4e55..d351b506 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -586,11 +586,69 @@ def write_int(value, stream):
stream.write(struct.pack("!i", value))
+def read_bool(stream):
+ length = stream.read(1)
+ if not length:
+ raise EOFError
+ return struct.unpack("!?", length)[0]
+
+
def write_with_length(obj, stream):
write_int(len(obj), stream)
stream.write(obj)
+class ChunkedStream(object):
+
+ """
+ This file-like object takes a stream of data, of unknown length, and breaks it into fixed
+ length frames. The intended use case is serializing large data and sending it immediately over
+ a socket -- we do not want to buffer the entire data before sending it, but the receiving end
+ needs to know whether or not there is more data coming.
+
+ It works by buffering the incoming data in some fixed-size chunks. If the buffer is full, it
+ first sends the buffer size, then the data. This repeats as long as there is more data to send.
+ When this is closed, it sends the length of whatever data is in the buffer, then that data, and
+ finally a "length" of -1 to indicate the stream has completed.
+ """
+
+ def __init__(self, wrapped, buffer_size):
+ self.buffer_size = buffer_size
+ self.buffer = bytearray(buffer_size)
+ self.current_pos = 0
+ self.wrapped = wrapped
+
+ def write(self, bytes):
+ byte_pos = 0
+ byte_remaining = len(bytes)
+ while byte_remaining > 0:
+ new_pos = byte_remaining + self.current_pos
+ if new_pos < self.buffer_size:
+ # just put it in our buffer
+ self.buffer[self.current_pos:new_pos] = bytes[byte_pos:]
+ self.current_pos = new_pos
+ byte_remaining = 0
+ else:
+ # fill the buffer, send the length then the contents, and start filling again
+ space_left = self.buffer_size - self.current_pos
+ new_byte_pos = byte_pos + space_left
+ self.buffer[self.current_pos:self.buffer_size] = bytes[byte_pos:new_byte_pos]
+ write_int(self.buffer_size, self.wrapped)
+ self.wrapped.write(self.buffer)
+ byte_remaining -= space_left
+ byte_pos = new_byte_pos
+ self.current_pos = 0
+
+ def close(self):
+ # if there is anything left in the buffer, write it out first
+ if self.current_pos > 0:
+ write_int(self.current_pos, self.wrapped)
+ self.wrapped.write(self.buffer[:self.current_pos])
+ # -1 length indicates to the receiving end that we're done.
+ write_int(-1, self.wrapped)
+ self.wrapped.close()
+
+
if __name__ == '__main__':
import doctest
(failure_count, test_count) = doctest.testmod()
http://git-wip-us.apache.org/repos/asf/spark/blob/dd0e7cf5/python/pyspark/test_broadcast.py
----------------------------------------------------------------------
diff --git a/python/pyspark/test_broadcast.py b/python/pyspark/test_broadcast.py
new file mode 100644
index 0000000..ce7ca83
--- /dev/null
+++ b/python/pyspark/test_broadcast.py
@@ -0,0 +1,126 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import random
+import tempfile
+import unittest
+
+try:
+ import xmlrunner
+except ImportError:
+ xmlrunner = None
+
+from pyspark.broadcast import Broadcast
+from pyspark.conf import SparkConf
+from pyspark.context import SparkContext
+from pyspark.java_gateway import launch_gateway
+from pyspark.serializers import ChunkedStream
+
+
+class BroadcastTest(unittest.TestCase):
+
+ def tearDown(self):
+ if getattr(self, "sc", None) is not None:
+ self.sc.stop()
+ self.sc = None
+
+ def _test_encryption_helper(self, vs):
+ """
+ Creates a broadcast variable for each value in vs, and runs a simple job to make sure the
+ value is the same when it's read in the executors. Also makes sure there are no task
+ failures.
+ """
+ bs = [self.sc.broadcast(value=v) for v in vs]
+ exec_values = self.sc.parallelize(range(2)).map(lambda x: [b.value for b in bs]).collect()
+ for ev in exec_values:
+ self.assertEqual(ev, vs)
+ # make sure there are no task failures
+ status = self.sc.statusTracker()
+ for jid in status.getJobIdsForGroup():
+ for sid in status.getJobInfo(jid).stageIds:
+ stage_info = status.getStageInfo(sid)
+ self.assertEqual(0, stage_info.numFailedTasks)
+
+ def _test_multiple_broadcasts(self, *extra_confs):
+ """
+ Test broadcast variables make it OK to the executors. Tests multiple broadcast variables,
+ and also multiple jobs.
+ """
+ conf = SparkConf()
+ for key, value in extra_confs:
+ conf.set(key, value)
+ conf.setMaster("local-cluster[2,1,1024]")
+ self.sc = SparkContext(conf=conf)
+ self._test_encryption_helper([5])
+ self._test_encryption_helper([5, 10, 20])
+
+ def test_broadcast_with_encryption(self):
+ self._test_multiple_broadcasts(("spark.io.encryption.enabled", "true"))
+
+ def test_broadcast_no_encryption(self):
+ self._test_multiple_broadcasts()
+
+
+class BroadcastFrameProtocolTest(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ gateway = launch_gateway(SparkConf())
+ cls._jvm = gateway.jvm
+ cls.longMessage = True
+ random.seed(42)
+
+ def _test_chunked_stream(self, data, py_buf_size):
+ # write data using the chunked protocol from python.
+ chunked_file = tempfile.NamedTemporaryFile(delete=False)
+ dechunked_file = tempfile.NamedTemporaryFile(delete=False)
+ dechunked_file.close()
+ try:
+ out = ChunkedStream(chunked_file, py_buf_size)
+ out.write(data)
+ out.close()
+ # now try to read it in java
+ jin = self._jvm.java.io.FileInputStream(chunked_file.name)
+ jout = self._jvm.java.io.FileOutputStream(dechunked_file.name)
+ self._jvm.DechunkedInputStream.dechunkAndCopyToOutput(jin, jout)
+ # java should have decoded it back to the original data
+ self.assertEqual(len(data), os.stat(dechunked_file.name).st_size)
+ with open(dechunked_file.name, "rb") as f:
+ byte = f.read(1)
+ idx = 0
+ while byte:
+ self.assertEqual(data[idx], bytearray(byte)[0], msg="idx = " + str(idx))
+ byte = f.read(1)
+ idx += 1
+ finally:
+ os.unlink(chunked_file.name)
+ os.unlink(dechunked_file.name)
+
+ def test_chunked_stream(self):
+ def random_bytes(n):
+ return bytearray(random.getrandbits(8) for _ in range(n))
+ for data_length in [1, 10, 100, 10000]:
+ for buffer_length in [1, 2, 5, 8192]:
+ self._test_chunked_stream(random_bytes(data_length), buffer_length)
+
+if __name__ == '__main__':
+ from pyspark.test_broadcast import *
+ if xmlrunner:
+ unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'))
+ else:
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/spark/blob/dd0e7cf5/python/pyspark/test_serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/test_serializers.py b/python/pyspark/test_serializers.py
new file mode 100644
index 0000000..5064e9f
--- /dev/null
+++ b/python/pyspark/test_serializers.py
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import io
+import math
+import struct
+import sys
+import unittest
+
+try:
+ import xmlrunner
+except ImportError:
+ xmlrunner = None
+
+from pyspark import serializers
+
+
+def read_int(b):
+ return struct.unpack("!i", b)[0]
+
+
+def write_int(i):
+ return struct.pack("!i", i)
+
+
+class SerializersTest(unittest.TestCase):
+
+ def test_chunked_stream(self):
+ original_bytes = bytearray(range(100))
+ for data_length in [1, 10, 100]:
+ for buffer_length in [1, 2, 3, 5, 20, 99, 100, 101, 500]:
+ dest = ByteArrayOutput()
+ stream_out = serializers.ChunkedStream(dest, buffer_length)
+ stream_out.write(original_bytes[:data_length])
+ stream_out.close()
+ num_chunks = int(math.ceil(float(data_length) / buffer_length))
+ # length for each chunk, and a final -1 at the very end
+ exp_size = (num_chunks + 1) * 4 + data_length
+ self.assertEqual(len(dest.buffer), exp_size)
+ dest_pos = 0
+ data_pos = 0
+ for chunk_idx in range(num_chunks):
+ chunk_length = read_int(dest.buffer[dest_pos:(dest_pos + 4)])
+ if chunk_idx == num_chunks - 1:
+ exp_length = data_length % buffer_length
+ if exp_length == 0:
+ exp_length = buffer_length
+ else:
+ exp_length = buffer_length
+ self.assertEqual(chunk_length, exp_length)
+ dest_pos += 4
+ dest_chunk = dest.buffer[dest_pos:dest_pos + chunk_length]
+ orig_chunk = original_bytes[data_pos:data_pos + chunk_length]
+ self.assertEqual(dest_chunk, orig_chunk)
+ dest_pos += chunk_length
+ data_pos += chunk_length
+ # ends with a -1
+ self.assertEqual(dest.buffer[-4:], write_int(-1))
+
+
+class ByteArrayOutput(object):
+ def __init__(self):
+ self.buffer = bytearray()
+
+ def write(self, b):
+ self.buffer += b
+
+ def close(self):
+ pass
+
+if __name__ == '__main__':
+ from pyspark.test_serializers import *
+ if xmlrunner:
+ unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'))
+ else:
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/spark/blob/dd0e7cf5/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index f3cb6ae..74f194f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -30,7 +30,7 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.java_gateway import local_connect_and_auth
from pyspark.taskcontext import TaskContext
from pyspark.files import SparkFiles
-from pyspark.serializers import write_with_length, write_int, read_long, \
+from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer
from pyspark import shuffle
@@ -149,16 +149,34 @@ def main(infile, outfile):
importlib.invalidate_caches()
# fetch names and values of broadcast variables
+ needs_broadcast_decryption_server = read_bool(infile)
num_broadcast_variables = read_int(infile)
+ if needs_broadcast_decryption_server:
+ # read the decrypted data from a server in the jvm
+ port = read_int(infile)
+ auth_secret = utf8_deserializer.loads(infile)
+ (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)
+
for _ in range(num_broadcast_variables):
bid = read_long(infile)
if bid >= 0:
- path = utf8_deserializer.loads(infile)
- _broadcastRegistry[bid] = Broadcast(path=path)
+ if needs_broadcast_decryption_server:
+ read_bid = read_long(broadcast_sock_file)
+ assert(read_bid == bid)
+ _broadcastRegistry[bid] = \
+ Broadcast(sock_file=broadcast_sock_file)
+ else:
+ path = utf8_deserializer.loads(infile)
+ _broadcastRegistry[bid] = Broadcast(path=path)
+
else:
bid = - bid - 1
_broadcastRegistry.pop(bid)
+ if needs_broadcast_decryption_server:
+ broadcast_sock_file.write(b'1')
+ broadcast_sock_file.close()
+
_accumulatorRegistry.clear()
is_sql_udf = read_int(infile)
if is_sql_udf:
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org