Posted to commits@spark.apache.org by ir...@apache.org on 2018/09/25 16:57:49 UTC

[2/3] spark git commit: [PYSPARK] Updates to pyspark broadcast

[PYSPARK] Updates to pyspark broadcast

(cherry picked from commit 09dd34cb1706f2477a89174d6a1a0f17ed5b0a65)


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/dd0e7cf5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/dd0e7cf5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/dd0e7cf5

Branch: refs/heads/branch-2.2
Commit: dd0e7cf5287148618404593ca095dd900b6e993f
Parents: fc1c4e7
Author: Imran Rashid <ir...@cloudera.com>
Authored: Mon Aug 13 21:35:34 2018 -0500
Committer: Imran Rashid <ir...@cloudera.com>
Committed: Tue Sep 25 11:46:03 2018 -0500

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala | 349 ++++++++++++++++---
 .../spark/api/python/PythonRDDSuite.scala       |  23 +-
 dev/sparktestsupport/modules.py                 |   2 +
 python/pyspark/broadcast.py                     |  58 ++-
 python/pyspark/context.py                       |  63 +++-
 python/pyspark/serializers.py                   |  58 +++
 python/pyspark/test_broadcast.py                | 126 +++++++
 python/pyspark/test_serializers.py              |  90 +++++
 python/pyspark/worker.py                        |  24 +-
 9 files changed, 705 insertions(+), 88 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/dd0e7cf5/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 7b5a179..2f4e3bc 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -24,8 +24,10 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.concurrent.Promise
+import scala.concurrent.duration.Duration
 import scala.language.existentials
-import scala.util.control.NonFatal
+import scala.util.Try
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.io.compress.CompressionCodec
@@ -37,6 +39,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.input.PortableDataStream
 import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.util._
@@ -293,19 +296,51 @@ private[spark] class PythonRunner(
         val newBids = broadcastVars.map(_.id).toSet
         // number of different broadcasts
         val toRemove = oldBids.diff(newBids)
-        val cnt = toRemove.size + newBids.diff(oldBids).size
+        val addedBids = newBids.diff(oldBids)
+        val cnt = toRemove.size + addedBids.size
+        val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty
+        dataOut.writeBoolean(needsDecryptionServer)
         dataOut.writeInt(cnt)
-        for (bid <- toRemove) {
-          // remove the broadcast from worker
-          dataOut.writeLong(- bid - 1)  // bid >= 0
-          oldBids.remove(bid)
+        def sendBidsToRemove(): Unit = {
+          for (bid <- toRemove) {
+            // remove the broadcast from worker
+            dataOut.writeLong(-bid - 1) // bid >= 0
+            oldBids.remove(bid)
+          }
         }
-        for (broadcast <- broadcastVars) {
-          if (!oldBids.contains(broadcast.id)) {
+        if (needsDecryptionServer) {
+          // if there is encryption, we set up a server which reads the encrypted files, and sends
+          // the decrypted data to python
+          val idsAndFiles = broadcastVars.flatMap { broadcast =>
+            if (oldBids.contains(broadcast.id)) {
+              None
+            } else {
+              Some((broadcast.id, broadcast.value.path))
+            }
+          }
+          val server = new EncryptedPythonBroadcastServer(env, idsAndFiles)
+          dataOut.writeInt(server.port)
+          logTrace(s"broadcast decryption server setup on ${server.port}")
+          PythonRDD.writeUTF(server.secret, dataOut)
+          sendBidsToRemove()
+          idsAndFiles.foreach { case (id, _) =>
             // send new broadcast
-            dataOut.writeLong(broadcast.id)
-            PythonRDD.writeUTF(broadcast.value.path, dataOut)
-            oldBids.add(broadcast.id)
+            dataOut.writeLong(id)
+            oldBids.add(id)
+          }
+          dataOut.flush()
+          logTrace("waiting for python to read decrypted broadcast data from server")
+          server.waitTillBroadcastDataSent()
+          logTrace("done sending decrypted data to python")
+        } else {
+          sendBidsToRemove()
+          for (broadcast <- broadcastVars) {
+            if (!oldBids.contains(broadcast.id)) {
+              // send new broadcast
+              dataOut.writeLong(broadcast.id)
+              PythonRDD.writeUTF(broadcast.value.path, dataOut)
+              oldBids.add(broadcast.id)
+            }
           }
         }
         dataOut.flush()
@@ -482,27 +517,34 @@ private[spark] object PythonRDD extends Logging {
 
   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
   JavaRDD[Array[Byte]] = {
-    val file = new DataInputStream(new FileInputStream(filename))
+    readRDDFromInputStream(sc.sc, new FileInputStream(filename), parallelism)
+  }
+
+  def readRDDFromInputStream(
+      sc: SparkContext,
+      in: InputStream,
+      parallelism: Int): JavaRDD[Array[Byte]] = {
+    val din = new DataInputStream(in)
     try {
       val objs = new mutable.ArrayBuffer[Array[Byte]]
       try {
         while (true) {
-          val length = file.readInt()
+          val length = din.readInt()
           val obj = new Array[Byte](length)
-          file.readFully(obj)
+          din.readFully(obj)
           objs += obj
         }
       } catch {
         case eof: EOFException => // No-op
       }
-      JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
+      JavaRDD.fromRDD(sc.parallelize(objs, parallelism))
     } finally {
-      file.close()
+      din.close()
     }
   }
 
-  def readBroadcastFromFile(sc: JavaSparkContext, path: String): Broadcast[PythonBroadcast] = {
-    sc.broadcast(new PythonBroadcast(path))
+  def setupBroadcast(path: String): PythonBroadcast = {
+    new PythonBroadcast(path)
   }
 
   def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
@@ -712,34 +754,15 @@ private[spark] object PythonRDD extends Logging {
    *         data collected from this job, and the secret for authentication.
    */
   def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
-    val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
-    // Close the socket if no connection in 15 seconds
-    serverSocket.setSoTimeout(15000)
-
-    new Thread(threadName) {
-      setDaemon(true)
-      override def run() {
-        try {
-          val sock = serverSocket.accept()
-          authHelper.authClient(sock)
-
-          val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
-          Utils.tryWithSafeFinally {
-            writeIteratorToStream(items, out)
-          } {
-            out.close()
-            sock.close()
-          }
-        } catch {
-          case NonFatal(e) =>
-            logError(s"Error while sending iterator", e)
-        } finally {
-          serverSocket.close()
-        }
+    val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { s =>
+      val out = new DataOutputStream(new BufferedOutputStream(s.getOutputStream()))
+      Utils.tryWithSafeFinally {
+        writeIteratorToStream(items, out)
+      } {
+        out.close()
       }
-    }.start()
-
-    Array(serverSocket.getLocalPort, authHelper.secret)
+    }
+    Array(port, secret)
   }
 
   private def getMergedConf(confAsMap: java.util.HashMap[String, String],
@@ -957,13 +980,11 @@ private[spark] class PythonAccumulatorV2(
   }
 }
 
-/**
- * A Wrapper for Python Broadcast, which is written into disk by Python. It also will
- * write the data into disk after deserialization, then Python can read it from disks.
- */
 // scalastyle:off no.finalize
 private[spark] class PythonBroadcast(@transient var path: String) extends Serializable
-  with Logging {
+    with Logging {
+
+  private var encryptionServer: PythonServer[Unit] = null
 
   /**
    * Read data from disks, then copy it to `out`
@@ -1005,5 +1026,233 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
       }
     }
   }
+
+  def setupEncryptionServer(): Array[Any] = {
+    encryptionServer = new PythonServer[Unit]("broadcast-encrypt-server") {
+      override def handleConnection(sock: Socket): Unit = {
+        val env = SparkEnv.get
+        val in = sock.getInputStream()
+        val dir = new File(Utils.getLocalDir(env.conf))
+        val file = File.createTempFile("broadcast", "", dir)
+        path = file.getAbsolutePath
+        val out = env.serializerManager.wrapForEncryption(new FileOutputStream(path))
+        DechunkedInputStream.dechunkAndCopyToOutput(in, out)
+      }
+    }
+    Array(encryptionServer.port, encryptionServer.secret)
+  }
+
+  def waitTillDataReceived(): Unit = encryptionServer.getResult()
 }
 // scalastyle:on no.finalize
+
+/**
+ * The inverse of pyspark's ChunkedStream, for reading data that python sends in
+ * length-prefixed chunks.  Tested via the python tests.
+ */
+private[spark] class DechunkedInputStream(wrapped: InputStream) extends InputStream with Logging {
+  private val din = new DataInputStream(wrapped)
+  private var remainingInChunk = din.readInt()
+
+  override def read(): Int = {
+    val into = new Array[Byte](1)
+    val n = read(into, 0, 1)
+    if (n == -1) {
+      -1
+    } else {
+      // if you just cast a byte to an int, then anything > 127 is negative, which is interpreted
+      // as an EOF
+      into(0) & 0xFF
+    }
+  }
+
+  override def read(dest: Array[Byte], off: Int, len: Int): Int = {
+    if (remainingInChunk == -1) {
+      return -1
+    }
+    var destSpace = len
+    var destPos = off
+    while (destSpace > 0 && remainingInChunk != -1) {
+      val toCopy = math.min(remainingInChunk, destSpace)
+      val read = din.read(dest, destPos, toCopy)
+      destPos += read
+      destSpace -= read
+      remainingInChunk -= read
+      if (remainingInChunk == 0) {
+        remainingInChunk = din.readInt()
+      }
+    }
+    assert(destSpace == 0 || remainingInChunk == -1)
+    return destPos - off
+  }
+
+  override def close(): Unit = wrapped.close()
+}
+
+/**
+ * The inverse of pyspark's ChunkedStream for sending data of unknown size.
+ *
+ * We might be serializing a really large object from python -- we don't want
+ * python to buffer the whole thing in memory, nor can it write to a file,
+ * so we don't know the length in advance.  So python writes it in chunks, each chunk
+ * preceded by a length, till we get a "length" of -1 which serves as EOF.
+ *
+ * Tested from python tests.
+ */
+private[spark] object DechunkedInputStream {
+
+  /**
+   * Dechunks the input, copies to output, and closes both input and the output safely.
+   */
+  def dechunkAndCopyToOutput(chunked: InputStream, out: OutputStream): Unit = {
+    val dechunked = new DechunkedInputStream(chunked)
+    Utils.tryWithSafeFinally {
+      Utils.copyStream(dechunked, out)
+    } {
+      JavaUtils.closeQuietly(out)
+      JavaUtils.closeQuietly(dechunked)
+    }
+  }
+}
+
+/**
+ * Creates a server in the jvm to communicate with python for handling one batch of data, with
+ * authentication and error handling.
+ */
+private[spark] abstract class PythonServer[T](
+    authHelper: SocketAuthHelper,
+    threadName: String) {
+
+  def this(env: SparkEnv, threadName: String) = this(new SocketAuthHelper(env.conf), threadName)
+  def this(threadName: String) = this(SparkEnv.get, threadName)
+
+  val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { sock =>
+    promise.complete(Try(handleConnection(sock)))
+  }
+
+  /**
+   * Handle a connection which has already been authenticated.  Any error from this function
+   * will clean up this connection and the entire server, and get propagated to [[getResult]].
+   */
+  def handleConnection(sock: Socket): T
+
+  val promise = Promise[T]()
+
+  /**
+   * Blocks indefinitely until [[handleConnection]] finishes, and returns that result.  If
+   * handleConnection throws an exception, this will throw an exception which includes the original
+   * exception as a cause.
+   */
+  def getResult(): T = {
+    getResult(Duration.Inf)
+  }
+
+  def getResult(wait: Duration): T = {
+    ThreadUtils.awaitResult(promise.future, wait)
+  }
+
+}
+
+private[spark] object PythonServer {
+
+  /**
+   * Create a socket server and run user function on the socket in a background thread.
+   *
+   * The socket server accepts only one connection, and closes if no connection is made
+   * within 15 seconds.
+   *
+   * The thread will terminate after the supplied user function returns, or if it throws an
+   * exception.
+   *
+   * If you need the result of the supplied function, create a subclass of [[PythonServer]].
+   *
+   * @return The port number of a local socket and the secret for authentication.
+   */
+  def setupOneConnectionServer(
+      authHelper: SocketAuthHelper,
+      threadName: String)
+      (func: Socket => Unit): (Int, String) = {
+    val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
+    // Close the socket if no connection in 15 seconds
+    serverSocket.setSoTimeout(15000)
+
+    new Thread(threadName) {
+      setDaemon(true)
+      override def run(): Unit = {
+        var sock: Socket = null
+        try {
+          sock = serverSocket.accept()
+          authHelper.authClient(sock)
+          func(sock)
+        } finally {
+          JavaUtils.closeQuietly(serverSocket)
+          JavaUtils.closeQuietly(sock)
+        }
+      }
+    }.start()
+    (serverSocket.getLocalPort, authHelper.secret)
+  }
+}
+
+/**
+ * Sends decrypted broadcast data to python worker.  See [[PythonRunner]] for entire protocol.
+ */
+private[spark] class EncryptedPythonBroadcastServer(
+    val env: SparkEnv,
+    val idsAndFiles: Seq[(Long, String)])
+    extends PythonServer[Unit]("broadcast-decrypt-server") with Logging {
+
+  override def handleConnection(socket: Socket): Unit = {
+    val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream()))
+    var socketIn: InputStream = null
+    // send the broadcast id, then the decrypted data.  We don't need to send the length;
+    // the python pickle module just needs a stream.
+    Utils.tryWithSafeFinally {
+      idsAndFiles.foreach { case (id, path) =>
+        out.writeLong(id)
+        val in = env.serializerManager.wrapForEncryption(new FileInputStream(path))
+        Utils.tryWithSafeFinally {
+          Utils.copyStream(in, out, false)
+        } {
+          in.close()
+        }
+      }
+      logTrace("waiting for python to accept broadcast data over socket")
+      out.flush()
+      socketIn = socket.getInputStream()
+      socketIn.read()
+      logTrace("done serving broadcast data")
+    } {
+      JavaUtils.closeQuietly(socketIn)
+      JavaUtils.closeQuietly(out)
+    }
+  }
+
+  def waitTillBroadcastDataSent(): Unit = {
+    getResult()
+  }
+}
+
+/**
+ * Helper for making RDD[Array[Byte]] from some python data, by reading the data from python
+ * over a socket.  This is used in preference to writing data to a file when encryption is enabled.
+ */
+private[spark] abstract class PythonRDDServer
+    extends PythonServer[JavaRDD[Array[Byte]]]("pyspark-parallelize-server") {
+
+  def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
+    val in = sock.getInputStream()
+    val dechunkedInput: InputStream = new DechunkedInputStream(in)
+    streamToRDD(dechunkedInput)
+  }
+
+  protected def streamToRDD(input: InputStream): RDD[Array[Byte]]
+
+}
+
+private[spark] class PythonParallelizeServer(sc: SparkContext, parallelism: Int)
+    extends PythonRDDServer {
+
+  override protected def streamToRDD(input: InputStream): RDD[Array[Byte]] = {
+    PythonRDD.readRDDFromInputStream(sc, input, parallelism)
+  }
+}
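
For reference, the framing that DechunkedInputStream undoes is the one pyspark's ChunkedStream
produces: each chunk is a 4-byte big-endian length followed by that many bytes, and a length of
-1 marks end-of-stream. A minimal pure-python sketch of both sides of that framing (the helper
names below are illustrative only, not part of this patch):

    import io
    import struct

    def write_chunked(data, out, chunk_size=8192):
        # mirrors the framing ChunkedStream produces: length-prefixed chunks, then -1
        for start in range(0, len(data), chunk_size):
            chunk = data[start:start + chunk_size]
            out.write(struct.pack("!i", len(chunk)))   # 4-byte big-endian length
            out.write(chunk)
        out.write(struct.pack("!i", -1))               # -1 length == end-of-stream

    def read_dechunked(inp):
        # mirrors what DechunkedInputStream does on the jvm side
        result = bytearray()
        while True:
            (length,) = struct.unpack("!i", inp.read(4))
            if length == -1:
                return bytes(result)
            result += inp.read(length)                 # a real socket reader would loop on short reads

    buf = io.BytesIO()
    write_chunked(b"x" * 20000, buf)
    buf.seek(0)
    assert read_dechunked(buf) == b"x" * 20000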

http://git-wip-us.apache.org/repos/asf/spark/blob/dd0e7cf5/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
index 05b4e67..6f9b583 100644
--- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -18,9 +18,13 @@
 package org.apache.spark.api.python
 
 import java.io.{ByteArrayOutputStream, DataOutputStream}
+import java.net.{InetAddress, Socket}
 import java.nio.charset.StandardCharsets
 
-import org.apache.spark.SparkFunSuite
+import scala.concurrent.duration.Duration
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.security.SocketAuthHelper
 
 class PythonRDDSuite extends SparkFunSuite {
 
@@ -44,4 +48,21 @@ class PythonRDDSuite extends SparkFunSuite {
       ("a".getBytes(StandardCharsets.UTF_8), null),
       (null, "b".getBytes(StandardCharsets.UTF_8))), buffer)
   }
+
+  test("python server error handling") {
+    val authHelper = new SocketAuthHelper(new SparkConf())
+    val errorServer = new ExceptionPythonServer(authHelper)
+    val client = new Socket(InetAddress.getLoopbackAddress(), errorServer.port)
+    authHelper.authToServer(client)
+    val ex = intercept[Exception] { errorServer.getResult(Duration(1, "second")) }
+    assert(ex.getCause().getMessage().contains("exception within handleConnection"))
+  }
+
+  class ExceptionPythonServer(authHelper: SocketAuthHelper)
+      extends PythonServer[Unit](authHelper, "error-server") {
+
+    override def handleConnection(sock: Socket): Unit = {
+      throw new Exception("exception within handleConnection")
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/dd0e7cf5/dev/sparktestsupport/modules.py
----------------------------------------------------------------------
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 2971e0d..b7683f9 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -353,6 +353,8 @@ pyspark_core = Module(
         "pyspark.profiler",
         "pyspark.shuffle",
         "pyspark.tests",
+        "pyspark.test_broadcast",
+        "pyspark.test_serializers",
         "pyspark.util",
     ]
 )

http://git-wip-us.apache.org/repos/asf/spark/blob/dd0e7cf5/python/pyspark/broadcast.py
----------------------------------------------------------------------
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index 02fc515..3f1298e 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -15,13 +15,16 @@
 # limitations under the License.
 #
 
+import gc
 import os
+import socket
 import sys
-import gc
 from tempfile import NamedTemporaryFile
 import threading
 
 from pyspark.cloudpickle import print_exec
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.serializers import ChunkedStream
 from pyspark.util import _exception_message
 
 if sys.version < '3':
@@ -64,19 +67,43 @@ class Broadcast(object):
     >>> large_broadcast = sc.broadcast(range(10000))
     """
 
-    def __init__(self, sc=None, value=None, pickle_registry=None, path=None):
+    def __init__(self, sc=None, value=None, pickle_registry=None, path=None,
+                 sock_file=None):
         """
         Should not be called directly by users -- use L{SparkContext.broadcast()}
         instead.
         """
         if sc is not None:
+            # we're on the driver.  We want the pickled data to end up in a file (maybe encrypted)
             f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
-            self._path = self.dump(value, f)
-            self._jbroadcast = sc._jvm.PythonRDD.readBroadcastFromFile(sc._jsc, self._path)
+            self._path = f.name
+            python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path)
+            if sc._encryption_enabled:
+                # with encryption, we ask the jvm to do the encryption for us; we send it the
+                # data over a socket
+                port, auth_secret = python_broadcast.setupEncryptionServer()
+                (encryption_sock_file, _) = local_connect_and_auth(port, auth_secret)
+                broadcast_out = ChunkedStream(encryption_sock_file, 8192)
+            else:
+                # no encryption, we can just write pickled data directly to the file from python
+                broadcast_out = f
+            self.dump(value, broadcast_out)
+            if sc._encryption_enabled:
+                python_broadcast.waitTillDataReceived()
+            self._jbroadcast = sc._jsc.broadcast(python_broadcast)
             self._pickle_registry = pickle_registry
         else:
+            # we're on an executor
             self._jbroadcast = None
-            self._path = path
+            if sock_file is not None:
+                # the jvm is doing decryption for us.  Read the value
+                # immediately from the sock_file
+                self._value = self.load(sock_file)
+            else:
+                # the jvm just dumps the pickled data in path -- we'll unpickle lazily when
+                # the value is requested
+                assert(path is not None)
+                self._path = path
 
     def dump(self, value, f):
         try:
@@ -89,24 +116,25 @@ class Broadcast(object):
             print_exec(sys.stderr)
             raise pickle.PicklingError(msg)
         f.close()
-        return f.name
 
-    def load(self, path):
+    def load_from_path(self, path):
         with open(path, 'rb', 1 << 20) as f:
-            # pickle.load() may create lots of objects, disable GC
-            # temporary for better performance
-            gc.disable()
-            try:
-                return pickle.load(f)
-            finally:
-                gc.enable()
+            return self.load(f)
+
+    def load(self, file):
+        # "file" could also be a socket
+        gc.disable()
+        try:
+            return pickle.load(file)
+        finally:
+            gc.enable()
 
     @property
     def value(self):
         """ Return the broadcasted value
         """
         if not hasattr(self, "_value") and self._path is not None:
-            self._value = self.load(self._path)
+            self._value = self.load_from_path(self._path)
         return self._value
 
     def unpersist(self, blocking=False):
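
The user-facing broadcast API is unchanged by this patch: broadcasts are still created on the
driver with SparkContext.broadcast() and read on executors through .value; only the transport
underneath differs when spark.io.encryption.enabled is true. A short usage sketch (standard
PySpark API, shown only to illustrate the two driver-side paths above):

    from pyspark import SparkConf, SparkContext

    # setting spark.io.encryption.enabled exercises the socket/encryption branch of __init__;
    # with the default (false) the pickled value is written straight to the temp file
    conf = SparkConf().set("spark.io.encryption.enabled", "true")
    conf.setMaster("local-cluster[2,1,1024]")   # same master the new python test uses
    sc = SparkContext(conf=conf)

    lookup = sc.broadcast({"a": 1, "b": 2})     # driver: pickled, then shipped (encrypted here)
    total = sc.parallelize(["a", "b", "a"]) \
              .map(lambda k: lookup.value[k]) \
              .sum()                            # executor: value loaded lazily from file or socket
    assert total == 4
    sc.stop()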

http://git-wip-us.apache.org/repos/asf/spark/blob/dd0e7cf5/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 02b1b2d..68e4c17 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -33,9 +33,9 @@ from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast, BroadcastPickleRegistry
 from pyspark.conf import SparkConf
 from pyspark.files import SparkFiles
-from pyspark.java_gateway import launch_gateway
+from pyspark.java_gateway import launch_gateway, local_connect_and_auth
 from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
-    PairDeserializer, AutoBatchedSerializer, NoOpSerializer
+    PairDeserializer, AutoBatchedSerializer, NoOpSerializer, ChunkedStream
 from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
 from pyspark.traceback_utils import CallSite, first_spark_call
@@ -189,6 +189,13 @@ class SparkContext(object):
         self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token)
         self._jsc.sc().register(self._javaAccumulator)
 
+        # If encryption is enabled, we need to set up a server in the jvm to read broadcast
+        # data via a socket.
+        # scala's mangled names w/ $ in them require special treatment.
+        encryption_conf = self._jvm.org.apache.spark.internal.config.__getattr__("package$")\
+            .__getattr__("MODULE$").IO_ENCRYPTION_ENABLED()
+        self._encryption_enabled = self._jsc.sc().conf().get(encryption_conf)
+
         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
         self.pythonVer = "%d.%d" % sys.version_info[:2]
 
@@ -479,25 +486,43 @@ class SparkContext(object):
                 return xrange(getStart(split), getStart(split + 1), step)
 
             return self.parallelize([], numSlices).mapPartitionsWithIndex(f)
-        # Calling the Java parallelize() method with an ArrayList is too slow,
-        # because it sends O(n) Py4J commands.  As an alternative, serialized
-        # objects are written to a file and loaded through textFile().
-        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
-        try:
-            # Make sure we distribute data evenly if it's smaller than self.batchSize
-            if "__len__" not in dir(c):
-                c = list(c)    # Make it a list so we can compute its length
-            batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
-            serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
-            serializer.dump_stream(c, tempFile)
-            tempFile.close()
-            readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
-            jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
-        finally:
-            # readRDDFromFile eagerily reads the file so we can delete right after.
-            os.unlink(tempFile.name)
+
+        # Make sure we distribute data evenly if it's smaller than self.batchSize
+        if "__len__" not in dir(c):
+            c = list(c)    # Make it a list so we can compute its length
+        batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
+        serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
+        jrdd = self._serialize_to_jvm(c, numSlices, serializer)
         return RDD(jrdd, self, serializer)
 
+    def _serialize_to_jvm(self, data, parallelism, serializer):
+        """
+        Using py4j to send a large dataset to the jvm is really slow, so we serialize it to a
+        temp file instead, or stream it over a socket if encryption is enabled.
+        """
+        if self._encryption_enabled:
+            # with encryption, we open a server in java and send the data directly
+            server = self._jvm.PythonParallelizeServer(self._jsc.sc(), parallelism)
+            (sock_file, _) = local_connect_and_auth(server.port(), server.secret())
+            chunked_out = ChunkedStream(sock_file, 8192)
+            serializer.dump_stream(data, chunked_out)
+            chunked_out.close()
+            # this call will block until the server has read all the data and processed it (or
+            # throws an exception)
+            return server.getResult()
+        else:
+            # without encryption, we serialize to a file, and we read the file in java and
+            # parallelize from there.
+            tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
+            try:
+                serializer.dump_stream(data, tempFile)
+                tempFile.close()
+                readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
+                return readRDDFromFile(self._jsc, tempFile.name, parallelism)
+            finally:
+                # we eagerly read the file so we can delete right after.
+                os.unlink(tempFile.name)
+
     def pickleFile(self, name, minPartitions=None):
         """
         Load an RDD previously saved using L{RDD.saveAsPickleFile} method.
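
From the caller's side sc.parallelize() is unchanged; _serialize_to_jvm just picks the transport.
A minimal sketch, assuming default configuration unless noted in the comments:

    from pyspark import SparkContext

    # with spark.io.encryption.enabled left at its default (false), the batches below are
    # serialized to a temp file and read back via PythonRDD.readRDDFromFile; setting it to
    # true instead streams them to PythonParallelizeServer over an authenticated socket
    sc = SparkContext(master="local[2]", appName="parallelize-demo")
    rdd = sc.parallelize(range(1000), 4)
    assert rdd.getNumPartitions() == 4
    assert rdd.sum() == sum(range(1000))
    sc.stop()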

http://git-wip-us.apache.org/repos/asf/spark/blob/dd0e7cf5/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 9bd4e55..d351b506 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -586,11 +586,69 @@ def write_int(value, stream):
     stream.write(struct.pack("!i", value))
 
 
+def read_bool(stream):
+    length = stream.read(1)
+    if not length:
+        raise EOFError
+    return struct.unpack("!?", length)[0]
+
+
 def write_with_length(obj, stream):
     write_int(len(obj), stream)
     stream.write(obj)
 
 
+class ChunkedStream(object):
+
+    """
+    This file-like object takes a stream of data, of unknown length, and breaks it into fixed
+    length frames.  The intended use case is serializing large data and sending it immediately over
+    a socket -- we do not want to buffer the entire data before sending it, but the receiving end
+    needs to know whether or not there is more data coming.
+
+    It works by buffering the incoming data in some fixed-size chunks.  If the buffer is full, it
+    first sends the buffer size, then the data.  This repeats as long as there is more data to send.
+    When this is closed, it sends the length of whatever data is in the buffer, then that data, and
+    finally a "length" of -1 to indicate the stream has completed.
+    """
+
+    def __init__(self, wrapped, buffer_size):
+        self.buffer_size = buffer_size
+        self.buffer = bytearray(buffer_size)
+        self.current_pos = 0
+        self.wrapped = wrapped
+
+    def write(self, bytes):
+        byte_pos = 0
+        byte_remaining = len(bytes)
+        while byte_remaining > 0:
+            new_pos = byte_remaining + self.current_pos
+            if new_pos < self.buffer_size:
+                # just put it in our buffer
+                self.buffer[self.current_pos:new_pos] = bytes[byte_pos:]
+                self.current_pos = new_pos
+                byte_remaining = 0
+            else:
+                # fill the buffer, send the length then the contents, and start filling again
+                space_left = self.buffer_size - self.current_pos
+                new_byte_pos = byte_pos + space_left
+                self.buffer[self.current_pos:self.buffer_size] = bytes[byte_pos:new_byte_pos]
+                write_int(self.buffer_size, self.wrapped)
+                self.wrapped.write(self.buffer)
+                byte_remaining -= space_left
+                byte_pos = new_byte_pos
+                self.current_pos = 0
+
+    def close(self):
+        # if there is anything left in the buffer, write it out first
+        if self.current_pos > 0:
+            write_int(self.current_pos, self.wrapped)
+            self.wrapped.write(self.buffer[:self.current_pos])
+        # -1 length indicates to the receiving end that we're done.
+        write_int(-1, self.wrapped)
+        self.wrapped.close()
+
+
 if __name__ == '__main__':
     import doctest
     (failure_count, test_count) = doctest.testmod()
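
A minimal sketch of how ChunkedStream frames its output, decoded by hand (the Sink helper below
is illustrative only; the new test_serializers.py does the same check more thoroughly):

    import struct
    from pyspark.serializers import ChunkedStream

    class Sink(object):
        # in-memory sink with a no-op close(), so the bytes survive ChunkedStream.close()
        def __init__(self):
            self.buffer = bytearray()
        def write(self, b):
            self.buffer += b
        def close(self):
            pass

    sink = Sink()
    out = ChunkedStream(sink, 4)        # tiny buffer so several chunks are produced
    out.write(b"0123456789")            # 10 bytes -> frames of 4, 4 and 2 bytes, then -1
    out.close()

    frames, pos = [], 0
    while True:
        (length,) = struct.unpack("!i", bytes(sink.buffer[pos:pos + 4]))
        pos += 4
        if length == -1:                # -1 length marks end-of-stream
            break
        frames.append(bytes(sink.buffer[pos:pos + length]))
        pos += length
    assert frames == [b"0123", b"4567", b"89"]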

http://git-wip-us.apache.org/repos/asf/spark/blob/dd0e7cf5/python/pyspark/test_broadcast.py
----------------------------------------------------------------------
diff --git a/python/pyspark/test_broadcast.py b/python/pyspark/test_broadcast.py
new file mode 100644
index 0000000..ce7ca83
--- /dev/null
+++ b/python/pyspark/test_broadcast.py
@@ -0,0 +1,126 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import random
+import tempfile
+import unittest
+
+try:
+    import xmlrunner
+except ImportError:
+    xmlrunner = None
+
+from pyspark.broadcast import Broadcast
+from pyspark.conf import SparkConf
+from pyspark.context import SparkContext
+from pyspark.java_gateway import launch_gateway
+from pyspark.serializers import ChunkedStream
+
+
+class BroadcastTest(unittest.TestCase):
+
+    def tearDown(self):
+        if getattr(self, "sc", None) is not None:
+            self.sc.stop()
+            self.sc = None
+
+    def _test_encryption_helper(self, vs):
+        """
+        Creates a broadcast variable for each value in vs, and runs a simple job to make sure the
+        value is the same when it's read in the executors.  Also makes sure there are no task
+        failures.
+        """
+        bs = [self.sc.broadcast(value=v) for v in vs]
+        exec_values = self.sc.parallelize(range(2)).map(lambda x: [b.value for b in bs]).collect()
+        for ev in exec_values:
+            self.assertEqual(ev, vs)
+        # make sure there are no task failures
+        status = self.sc.statusTracker()
+        for jid in status.getJobIdsForGroup():
+            for sid in status.getJobInfo(jid).stageIds:
+                stage_info = status.getStageInfo(sid)
+                self.assertEqual(0, stage_info.numFailedTasks)
+
+    def _test_multiple_broadcasts(self, *extra_confs):
+        """
+        Test that broadcast variables make it to the executors correctly.  Tests multiple
+        broadcast variables, and also multiple jobs.
+        """
+        conf = SparkConf()
+        for key, value in extra_confs:
+            conf.set(key, value)
+        conf.setMaster("local-cluster[2,1,1024]")
+        self.sc = SparkContext(conf=conf)
+        self._test_encryption_helper([5])
+        self._test_encryption_helper([5, 10, 20])
+
+    def test_broadcast_with_encryption(self):
+        self._test_multiple_broadcasts(("spark.io.encryption.enabled", "true"))
+
+    def test_broadcast_no_encryption(self):
+        self._test_multiple_broadcasts()
+
+
+class BroadcastFrameProtocolTest(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        gateway = launch_gateway(SparkConf())
+        cls._jvm = gateway.jvm
+        cls.longMessage = True
+        random.seed(42)
+
+    def _test_chunked_stream(self, data, py_buf_size):
+        # write data using the chunked protocol from python.
+        chunked_file = tempfile.NamedTemporaryFile(delete=False)
+        dechunked_file = tempfile.NamedTemporaryFile(delete=False)
+        dechunked_file.close()
+        try:
+            out = ChunkedStream(chunked_file, py_buf_size)
+            out.write(data)
+            out.close()
+            # now try to read it in java
+            jin = self._jvm.java.io.FileInputStream(chunked_file.name)
+            jout = self._jvm.java.io.FileOutputStream(dechunked_file.name)
+            self._jvm.DechunkedInputStream.dechunkAndCopyToOutput(jin, jout)
+            # java should have decoded it back to the original data
+            self.assertEqual(len(data), os.stat(dechunked_file.name).st_size)
+            with open(dechunked_file.name, "rb") as f:
+                byte = f.read(1)
+                idx = 0
+                while byte:
+                    self.assertEqual(data[idx], bytearray(byte)[0], msg="idx = " + str(idx))
+                    byte = f.read(1)
+                    idx += 1
+        finally:
+            os.unlink(chunked_file.name)
+            os.unlink(dechunked_file.name)
+
+    def test_chunked_stream(self):
+        def random_bytes(n):
+            return bytearray(random.getrandbits(8) for _ in range(n))
+        for data_length in [1, 10, 100, 10000]:
+            for buffer_length in [1, 2, 5, 8192]:
+                self._test_chunked_stream(random_bytes(data_length), buffer_length)
+
+if __name__ == '__main__':
+    from pyspark.test_broadcast import *
+    if xmlrunner:
+        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'))
+    else:
+        unittest.main()

http://git-wip-us.apache.org/repos/asf/spark/blob/dd0e7cf5/python/pyspark/test_serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/test_serializers.py b/python/pyspark/test_serializers.py
new file mode 100644
index 0000000..5064e9f
--- /dev/null
+++ b/python/pyspark/test_serializers.py
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import io
+import math
+import struct
+import sys
+import unittest
+
+try:
+    import xmlrunner
+except ImportError:
+    xmlrunner = None
+
+from pyspark import serializers
+
+
+def read_int(b):
+    return struct.unpack("!i", b)[0]
+
+
+def write_int(i):
+    return struct.pack("!i", i)
+
+
+class SerializersTest(unittest.TestCase):
+
+    def test_chunked_stream(self):
+        original_bytes = bytearray(range(100))
+        for data_length in [1, 10, 100]:
+            for buffer_length in [1, 2, 3, 5, 20, 99, 100, 101, 500]:
+                dest = ByteArrayOutput()
+                stream_out = serializers.ChunkedStream(dest, buffer_length)
+                stream_out.write(original_bytes[:data_length])
+                stream_out.close()
+                num_chunks = int(math.ceil(float(data_length) / buffer_length))
+                # length for each chunk, and a final -1 at the very end
+                exp_size = (num_chunks + 1) * 4 + data_length
+                self.assertEqual(len(dest.buffer), exp_size)
+                dest_pos = 0
+                data_pos = 0
+                for chunk_idx in range(num_chunks):
+                    chunk_length = read_int(dest.buffer[dest_pos:(dest_pos + 4)])
+                    if chunk_idx == num_chunks - 1:
+                        exp_length = data_length % buffer_length
+                        if exp_length == 0:
+                            exp_length = buffer_length
+                    else:
+                        exp_length = buffer_length
+                    self.assertEqual(chunk_length, exp_length)
+                    dest_pos += 4
+                    dest_chunk = dest.buffer[dest_pos:dest_pos + chunk_length]
+                    orig_chunk = original_bytes[data_pos:data_pos + chunk_length]
+                    self.assertEqual(dest_chunk, orig_chunk)
+                    dest_pos += chunk_length
+                    data_pos += chunk_length
+                # ends with a -1
+                self.assertEqual(dest.buffer[-4:], write_int(-1))
+
+
+class ByteArrayOutput(object):
+    def __init__(self):
+        self.buffer = bytearray()
+
+    def write(self, b):
+        self.buffer += b
+
+    def close(self):
+        pass
+
+if __name__ == '__main__':
+    from pyspark.test_serializers import *
+    if xmlrunner:
+        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'))
+    else:
+        unittest.main()

http://git-wip-us.apache.org/repos/asf/spark/blob/dd0e7cf5/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index f3cb6ae..74f194f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -30,7 +30,7 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.java_gateway import local_connect_and_auth
 from pyspark.taskcontext import TaskContext
 from pyspark.files import SparkFiles
-from pyspark.serializers import write_with_length, write_int, read_long, \
+from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer
 from pyspark import shuffle
 
@@ -149,16 +149,34 @@ def main(infile, outfile):
             importlib.invalidate_caches()
 
         # fetch names and values of broadcast variables
+        needs_broadcast_decryption_server = read_bool(infile)
         num_broadcast_variables = read_int(infile)
+        if needs_broadcast_decryption_server:
+            # read the decrypted data from a server in the jvm
+            port = read_int(infile)
+            auth_secret = utf8_deserializer.loads(infile)
+            (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)
+
         for _ in range(num_broadcast_variables):
             bid = read_long(infile)
             if bid >= 0:
-                path = utf8_deserializer.loads(infile)
-                _broadcastRegistry[bid] = Broadcast(path=path)
+                if needs_broadcast_decryption_server:
+                    read_bid = read_long(broadcast_sock_file)
+                    assert(read_bid == bid)
+                    _broadcastRegistry[bid] = \
+                        Broadcast(sock_file=broadcast_sock_file)
+                else:
+                    path = utf8_deserializer.loads(infile)
+                    _broadcastRegistry[bid] = Broadcast(path=path)
+
             else:
                 bid = - bid - 1
                 _broadcastRegistry.pop(bid)
 
+        if needs_broadcast_decryption_server:
+            broadcast_sock_file.write(b'1')
+            broadcast_sock_file.close()
+
         _accumulatorRegistry.clear()
         is_sql_udf = read_int(infile)
         if is_sql_udf:


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org