You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ir...@apache.org on 2018/09/17 19:40:35 UTC

[3/3] spark git commit: [PYSPARK] Updates to pyspark broadcast

[PYSPARK] Updates to pyspark broadcast


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/58419b92
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/58419b92
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/58419b92

Branch: refs/heads/master
Commit: 58419b92673c46911c25bc6c6b13397f880c6424
Parents: 553af22
Author: Imran Rashid <ir...@cloudera.com>
Authored: Mon Aug 13 21:35:34 2018 -0500
Committer: Imran Rashid <ir...@cloudera.com>
Committed: Mon Sep 17 14:06:09 2018 -0500

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala | 299 ++++++++++++++++---
 .../apache/spark/api/python/PythonRunner.scala  |  52 +++-
 .../spark/api/python/PythonRDDSuite.scala       |  23 +-
 dev/sparktestsupport/modules.py                 |   2 +
 python/pyspark/broadcast.py                     |  58 +++-
 python/pyspark/context.py                       |  64 ++--
 python/pyspark/serializers.py                   |  51 ++++
 python/pyspark/sql/session.py                   |  12 +-
 python/pyspark/sql/tests.py                     |  45 ++-
 python/pyspark/test_broadcast.py                | 126 ++++++++
 python/pyspark/test_serializers.py              |  90 ++++++
 python/pyspark/tests.py                         |   9 +-
 python/pyspark/worker.py                        |  22 +-
 .../spark/sql/api/python/PythonSQLUtils.scala   |  47 ++-
 .../sql/execution/arrow/ArrowConverters.scala   |   9 +-
 15 files changed, 789 insertions(+), 120 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index e639a84..8b5a7a9 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -24,8 +24,10 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.concurrent.Promise
+import scala.concurrent.duration.Duration
 import scala.language.existentials
-import scala.util.control.NonFatal
+import scala.util.Try
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.io.compress.CompressionCodec
@@ -37,6 +39,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.input.PortableDataStream
 import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.util._
@@ -169,27 +172,34 @@ private[spark] object PythonRDD extends Logging {
 
   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
   JavaRDD[Array[Byte]] = {
-    val file = new DataInputStream(new FileInputStream(filename))
+    readRDDFromInputStream(sc.sc, new FileInputStream(filename), parallelism)
+  }
+
+  def readRDDFromInputStream(
+      sc: SparkContext,
+      in: InputStream,
+      parallelism: Int): JavaRDD[Array[Byte]] = {
+    val din = new DataInputStream(in)
     try {
       val objs = new mutable.ArrayBuffer[Array[Byte]]
       try {
         while (true) {
-          val length = file.readInt()
+          val length = din.readInt()
           val obj = new Array[Byte](length)
-          file.readFully(obj)
+          din.readFully(obj)
           objs += obj
         }
       } catch {
         case eof: EOFException => // No-op
       }
-      JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
+      JavaRDD.fromRDD(sc.parallelize(objs, parallelism))
     } finally {
-      file.close()
+      din.close()
     }
   }
 
-  def readBroadcastFromFile(sc: JavaSparkContext, path: String): Broadcast[PythonBroadcast] = {
-    sc.broadcast(new PythonBroadcast(path))
+  def setupBroadcast(path: String): PythonBroadcast = {
+    new PythonBroadcast(path)
   }
 
   def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
@@ -419,34 +429,15 @@ private[spark] object PythonRDD extends Logging {
    */
   private[spark] def serveToStream(
       threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
-    val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
-    // Close the socket if no connection in 15 seconds
-    serverSocket.setSoTimeout(15000)
-
-    new Thread(threadName) {
-      setDaemon(true)
-      override def run() {
-        try {
-          val sock = serverSocket.accept()
-          authHelper.authClient(sock)
-
-          val out = new BufferedOutputStream(sock.getOutputStream)
-          Utils.tryWithSafeFinally {
-            writeFunc(out)
-          } {
-            out.close()
-            sock.close()
-          }
-        } catch {
-          case NonFatal(e) =>
-            logError(s"Error while sending iterator", e)
-        } finally {
-          serverSocket.close()
-        }
+    val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { s =>
+      val out = new BufferedOutputStream(s.getOutputStream())
+      Utils.tryWithSafeFinally {
+        writeFunc(out)
+      } {
+        out.close()
       }
-    }.start()
-
-    Array(serverSocket.getLocalPort, authHelper.secret)
+    }
+    Array(port, secret)
   }
 
   private def getMergedConf(confAsMap: java.util.HashMap[String, String],
@@ -664,13 +655,11 @@ private[spark] class PythonAccumulatorV2(
   }
 }
 
-/**
- * A Wrapper for Python Broadcast, which is written into disk by Python. It also will
- * write the data into disk after deserialization, then Python can read it from disks.
- */
 // scalastyle:off no.finalize
 private[spark] class PythonBroadcast(@transient var path: String) extends Serializable
-  with Logging {
+    with Logging {
+
+  private var encryptionServer: PythonServer[Unit] = null
 
   /**
    * Read data from disks, then copy it to `out`
@@ -713,5 +702,235 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
     }
     super.finalize()
   }
+
+  def setupEncryptionServer(): Array[Any] = {
+    encryptionServer = new PythonServer[Unit]("broadcast-encrypt-server") {
+      override def handleConnection(sock: Socket): Unit = {
+        val env = SparkEnv.get
+        val in = sock.getInputStream()
+        val dir = new File(Utils.getLocalDir(env.conf))
+        val file = File.createTempFile("broadcast", "", dir)
+        path = file.getAbsolutePath
+        val out = env.serializerManager.wrapForEncryption(new FileOutputStream(path))
+        DechunkedInputStream.dechunkAndCopyToOutput(in, out)
+      }
+    }
+    Array(encryptionServer.port, encryptionServer.secret)
+  }
+
+  def waitTillDataReceived(): Unit = encryptionServer.getResult()
 }
 // scalastyle:on no.finalize
+
+/**
+ * The inverse of pyspark's ChunkedStream for sending data of unknown size.
+ *
+ * We might be serializing a really large object from python -- we don't want
+ * python to buffer the whole thing in memory, nor can it write to a file,
+ * so we don't know the length in advance.  So python writes it in chunks, each chunk
+ * preceded by a length, till we get a "length" of -1 which serves as EOF.
+ *
+ * Tested from python tests.
+ */
+private[spark] class DechunkedInputStream(wrapped: InputStream) extends InputStream with Logging {
+  private val din = new DataInputStream(wrapped)
+  private var remainingInChunk = din.readInt()
+
+  override def read(): Int = {
+    val into = new Array[Byte](1)
+    val n = read(into, 0, 1)
+    if (n == -1) {
+      -1
+    } else {
+      // if you just cast a byte to an int, then anything > 127 is negative, which is interpreted
+      // as an EOF
+      val b = into(0)
+      if (b < 0) {
+        256 + b
+      } else {
+        b
+      }
+    }
+  }
+
+  override def read(dest: Array[Byte], off: Int, len: Int): Int = {
+    if (remainingInChunk == -1) {
+      return -1
+    }
+    var destSpace = len
+    var destPos = off
+    while (destSpace > 0 && remainingInChunk != -1) {
+      val toCopy = math.min(remainingInChunk, destSpace)
+      val read = din.read(dest, destPos, toCopy)
+      destPos += read
+      destSpace -= read
+      remainingInChunk -= read
+      if (remainingInChunk == 0) {
+        remainingInChunk = din.readInt()
+      }
+    }
+    assert(destSpace == 0 || remainingInChunk == -1)
+    return destPos - off
+  }
+
+  override def close(): Unit = wrapped.close()
+}
+
+private[spark] object DechunkedInputStream {
+
+  /**
+   * Dechunks the input, copies to output, and closes both input and the output safely.
+   */
+  def dechunkAndCopyToOutput(chunked: InputStream, out: OutputStream): Unit = {
+    val dechunked = new DechunkedInputStream(chunked)
+    Utils.tryWithSafeFinally {
+      Utils.copyStream(dechunked, out)
+    } {
+      JavaUtils.closeQuietly(out)
+      JavaUtils.closeQuietly(dechunked)
+    }
+  }
+}
+
+/**
+ * Creates a server in the jvm to communicate with python for handling one batch of data, with
+ * authentication and error handling.
+ */
+private[spark] abstract class PythonServer[T](
+    authHelper: SocketAuthHelper,
+    threadName: String) {
+
+  def this(env: SparkEnv, threadName: String) = this(new SocketAuthHelper(env.conf), threadName)
+  def this(threadName: String) = this(SparkEnv.get, threadName)
+
+  val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { sock =>
+    promise.complete(Try(handleConnection(sock)))
+  }
+
+  /**
+   * Handle a connection which has already been authenticated.  Any error from this function
+   * will clean up this connection and the entire server, and get propagated to [[getResult]].
+   */
+  def handleConnection(sock: Socket): T
+
+  val promise = Promise[T]()
+
+  /**
+   * Blocks indefinitely for [[handleConnection]] to finish, and returns that result.  If
+   * handleConnection throws an exception, this will throw an exception which includes the original
+   * exception as a cause.
+   */
+  def getResult(): T = {
+    getResult(Duration.Inf)
+  }
+
+  def getResult(wait: Duration): T = {
+    ThreadUtils.awaitResult(promise.future, wait)
+  }
+
+}
+
+private[spark] object PythonServer {
+
+  /**
+   * Create a socket server and run user function on the socket in a background thread.
+   *
+   * The socket server accepts only one connection, and closes if no connection arrives
+   * within 15 seconds.
+   *
+   * The thread will terminate after the supplied user function, or if there are any exceptions.
+   *
+   * If you need to get a result of the supplied function, create a subclass of [[PythonServer]]
+   *
+   * @return The port number of a local socket and the secret for authentication.
+   */
+  def setupOneConnectionServer(
+      authHelper: SocketAuthHelper,
+      threadName: String)
+      (func: Socket => Unit): (Int, String) = {
+    val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
+    // Close the socket if no connection in 15 seconds
+    serverSocket.setSoTimeout(15000)
+
+    new Thread(threadName) {
+      setDaemon(true)
+      override def run(): Unit = {
+        var sock: Socket = null
+        try {
+          sock = serverSocket.accept()
+          authHelper.authClient(sock)
+          func(sock)
+        } finally {
+          JavaUtils.closeQuietly(serverSocket)
+          JavaUtils.closeQuietly(sock)
+        }
+      }
+    }.start()
+    (serverSocket.getLocalPort, authHelper.secret)
+  }
+}
+
+/**
+ * Sends decrypted broadcast data to python worker.  See [[PythonRunner]] for entire protocol.
+ */
+private[spark] class EncryptedPythonBroadcastServer(
+    val env: SparkEnv,
+    val idsAndFiles: Seq[(Long, String)])
+    extends PythonServer[Unit]("broadcast-decrypt-server") with Logging {
+
+  override def handleConnection(socket: Socket): Unit = {
+    val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream()))
+    var socketIn: InputStream = null
+    // send the broadcast id, then the decrypted data.  We don't need to send the length, the
+    // python pickle module just needs a stream.
+    Utils.tryWithSafeFinally {
+      (idsAndFiles).foreach { case (id, path) =>
+        out.writeLong(id)
+        val in = env.serializerManager.wrapForEncryption(new FileInputStream(path))
+        Utils.tryWithSafeFinally {
+          Utils.copyStream(in, out, false)
+        } {
+          in.close()
+        }
+      }
+      logTrace("waiting for python to accept broadcast data over socket")
+      out.flush()
+      socketIn = socket.getInputStream()
+      socketIn.read()
+      logTrace("done serving broadcast data")
+    } {
+      JavaUtils.closeQuietly(socketIn)
+      JavaUtils.closeQuietly(out)
+    }
+  }
+
+  def waitTillBroadcastDataSent(): Unit = {
+    getResult()
+  }
+}
+
+/**
+ * Helper for making RDD[Array[Byte]] from some python data, by reading the data from python
+ * over a socket.  This is used in preference to writing data to a file when encryption is enabled.
+ */
+private[spark] abstract class PythonRDDServer
+    extends PythonServer[JavaRDD[Array[Byte]]]("pyspark-parallelize-server") {
+
+  def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
+    val in = sock.getInputStream()
+    val dechunkedInput: InputStream = new DechunkedInputStream(in)
+    streamToRDD(dechunkedInput)
+  }
+
+  protected def streamToRDD(input: InputStream): RDD[Array[Byte]]
+
+}
+
+private[spark] class PythonParallelizeServer(sc: SparkContext, parallelism: Int)
+    extends PythonRDDServer {
+
+  override protected def streamToRDD(input: InputStream): RDD[Array[Byte]] = {
+    PythonRDD.readRDDFromInputStream(sc, input, parallelism)
+  }
+}
+

http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 4c53bc2..6e53a04 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -289,19 +289,51 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         val newBids = broadcastVars.map(_.id).toSet
         // number of different broadcasts
         val toRemove = oldBids.diff(newBids)
-        val cnt = toRemove.size + newBids.diff(oldBids).size
+        val addedBids = newBids.diff(oldBids)
+        val cnt = toRemove.size + addedBids.size
+        val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty
+        dataOut.writeBoolean(needsDecryptionServer)
         dataOut.writeInt(cnt)
-        for (bid <- toRemove) {
-          // remove the broadcast from worker
-          dataOut.writeLong(- bid - 1)  // bid >= 0
-          oldBids.remove(bid)
+        def sendBidsToRemove(): Unit = {
+          for (bid <- toRemove) {
+            // remove the broadcast from worker
+            dataOut.writeLong(-bid - 1) // bid >= 0
+            oldBids.remove(bid)
+          }
         }
-        for (broadcast <- broadcastVars) {
-          if (!oldBids.contains(broadcast.id)) {
+        if (needsDecryptionServer) {
+          // if there is encryption, we setup a server which reads the encrypted files, and sends
+          // the decrypted data to python
+          val idsAndFiles = broadcastVars.flatMap { broadcast =>
+            if (!oldBids.contains(broadcast.id)) {
+              Some((broadcast.id, broadcast.value.path))
+            } else {
+              None
+            }
+          }
+          val server = new EncryptedPythonBroadcastServer(env, idsAndFiles)
+          dataOut.writeInt(server.port)
+          logTrace(s"broadcast decryption server setup on ${server.port}")
+          PythonRDD.writeUTF(server.secret, dataOut)
+          sendBidsToRemove()
+          idsAndFiles.foreach { case (id, _) =>
             // send new broadcast
-            dataOut.writeLong(broadcast.id)
-            PythonRDD.writeUTF(broadcast.value.path, dataOut)
-            oldBids.add(broadcast.id)
+            dataOut.writeLong(id)
+            oldBids.add(id)
+          }
+          dataOut.flush()
+          logTrace("waiting for python to read decrypted broadcast data from server")
+          server.waitTillBroadcastDataSent()
+          logTrace("done sending decrypted data to python")
+        } else {
+          sendBidsToRemove()
+          for (broadcast <- broadcastVars) {
+            if (!oldBids.contains(broadcast.id)) {
+              // send new broadcast
+              dataOut.writeLong(broadcast.id)
+              PythonRDD.writeUTF(broadcast.value.path, dataOut)
+              oldBids.add(broadcast.id)
+            }
           }
         }
         dataOut.flush()

http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
index 05b4e67..6f9b583 100644
--- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -18,9 +18,13 @@
 package org.apache.spark.api.python
 
 import java.io.{ByteArrayOutputStream, DataOutputStream}
+import java.net.{InetAddress, Socket}
 import java.nio.charset.StandardCharsets
 
-import org.apache.spark.SparkFunSuite
+import scala.concurrent.duration.Duration
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.security.SocketAuthHelper
 
 class PythonRDDSuite extends SparkFunSuite {
 
@@ -44,4 +48,21 @@ class PythonRDDSuite extends SparkFunSuite {
       ("a".getBytes(StandardCharsets.UTF_8), null),
       (null, "b".getBytes(StandardCharsets.UTF_8))), buffer)
   }
+
+  test("python server error handling") {
+    val authHelper = new SocketAuthHelper(new SparkConf())
+    val errorServer = new ExceptionPythonServer(authHelper)
+    val client = new Socket(InetAddress.getLoopbackAddress(), errorServer.port)
+    authHelper.authToServer(client)
+    val ex = intercept[Exception] { errorServer.getResult(Duration(1, "second")) }
+    assert(ex.getCause().getMessage().contains("exception within handleConnection"))
+  }
+
+  class ExceptionPythonServer(authHelper: SocketAuthHelper)
+      extends PythonServer[Unit](authHelper, "error-server") {
+
+    override def handleConnection(sock: Socket): Unit = {
+      throw new Exception("exception within handleConnection")
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/dev/sparktestsupport/modules.py
----------------------------------------------------------------------
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 2aa3555..e267fbf 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -387,6 +387,8 @@ pyspark_core = Module(
         "pyspark.profiler",
         "pyspark.shuffle",
         "pyspark.tests",
+        "pyspark.test_broadcast",
+        "pyspark.test_serializers",
         "pyspark.util",
     ]
 )

http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/python/pyspark/broadcast.py
----------------------------------------------------------------------
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index b3dfc99..1c7f2a7 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -15,13 +15,16 @@
 # limitations under the License.
 #
 
+import gc
 import os
+import socket
 import sys
-import gc
 from tempfile import NamedTemporaryFile
 import threading
 
 from pyspark.cloudpickle import print_exec
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.serializers import ChunkedStream
 from pyspark.util import _exception_message
 
 if sys.version < '3':
@@ -64,19 +67,43 @@ class Broadcast(object):
     >>> large_broadcast = sc.broadcast(range(10000))
     """
 
-    def __init__(self, sc=None, value=None, pickle_registry=None, path=None):
+    def __init__(self, sc=None, value=None, pickle_registry=None, path=None,
+                 sock_file=None):
         """
         Should not be called directly by users -- use L{SparkContext.broadcast()}
         instead.
         """
         if sc is not None:
+            # we're on the driver.  We want the pickled data to end up in a file (maybe encrypted)
             f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
-            self._path = self.dump(value, f)
-            self._jbroadcast = sc._jvm.PythonRDD.readBroadcastFromFile(sc._jsc, self._path)
+            self._path = f.name
+            python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path)
+            if sc._encryption_enabled:
+                # with encryption, we ask the jvm to do the encryption for us, we send it data
+                # over a socket
+                port, auth_secret = python_broadcast.setupEncryptionServer()
+                (encryption_sock_file, _) = local_connect_and_auth(port, auth_secret)
+                broadcast_out = ChunkedStream(encryption_sock_file, 8192)
+            else:
+                # no encryption, we can just write pickled data directly to the file from python
+                broadcast_out = f
+            self.dump(value, broadcast_out)
+            if sc._encryption_enabled:
+                python_broadcast.waitTillDataReceived()
+            self._jbroadcast = sc._jsc.broadcast(python_broadcast)
             self._pickle_registry = pickle_registry
         else:
+            # we're on an executor
             self._jbroadcast = None
-            self._path = path
+            if sock_file is not None:
+                # the jvm is doing decryption for us.  Read the value
+                # immediately from the sock_file
+                self._value = self.load(sock_file)
+            else:
+                # the jvm just dumps the pickled data in path -- we'll unpickle lazily when
+                # the value is requested
+                assert(path is not None)
+                self._path = path
 
     def dump(self, value, f):
         try:
@@ -89,24 +116,25 @@ class Broadcast(object):
             print_exec(sys.stderr)
             raise pickle.PicklingError(msg)
         f.close()
-        return f.name
 
-    def load(self, path):
+    def load_from_path(self, path):
         with open(path, 'rb', 1 << 20) as f:
-            # pickle.load() may create lots of objects, disable GC
-            # temporary for better performance
-            gc.disable()
-            try:
-                return pickle.load(f)
-            finally:
-                gc.enable()
+            return self.load(f)
+
+    def load(self, file):
+        # "file" could also be a socket
+        gc.disable()
+        try:
+            return pickle.load(file)
+        finally:
+            gc.enable()
 
     @property
     def value(self):
         """ Return the broadcasted value
         """
         if not hasattr(self, "_value") and self._path is not None:
-            self._value = self.load(self._path)
+            self._value = self.load_from_path(self._path)
         return self._value
 
     def unpersist(self, blocking=False):

http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 4cabae4..2c92c29 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -33,9 +33,9 @@ from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast, BroadcastPickleRegistry
 from pyspark.conf import SparkConf
 from pyspark.files import SparkFiles
-from pyspark.java_gateway import launch_gateway
+from pyspark.java_gateway import launch_gateway, local_connect_and_auth
 from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
-    PairDeserializer, AutoBatchedSerializer, NoOpSerializer
+    PairDeserializer, AutoBatchedSerializer, NoOpSerializer, ChunkedStream
 from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
 from pyspark.traceback_utils import CallSite, first_spark_call
@@ -189,6 +189,13 @@ class SparkContext(object):
         self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token)
         self._jsc.sc().register(self._javaAccumulator)
 
+        # If encryption is enabled, we need to setup a server in the jvm to read broadcast
+        # data via a socket.
+        # scala's mangled names w/ $ in them require special treatment.
+        encryption_conf = self._jvm.org.apache.spark.internal.config.__getattr__("package$")\
+            .__getattr__("MODULE$").IO_ENCRYPTION_ENABLED()
+        self._encryption_enabled = self._jsc.sc().conf().get(encryption_conf)
+
         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
         self.pythonVer = "%d.%d" % sys.version_info[:2]
 
@@ -498,23 +505,46 @@ class SparkContext(object):
         def reader_func(temp_filename):
             return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices)
 
-        jrdd = self._serialize_to_jvm(c, serializer, reader_func)
+        def createRDDServer():
+            return self._jvm.PythonParallelizeServer(self._jsc.sc(), numSlices)
+
+        jrdd = self._serialize_to_jvm(c, serializer, reader_func, createRDDServer)
         return RDD(jrdd, self, serializer)
 
-    def _serialize_to_jvm(self, data, serializer, reader_func):
-        """
-        Calling the Java parallelize() method with an ArrayList is too slow,
-        because it sends O(n) Py4J commands.  As an alternative, serialized
-        objects are written to a file and loaded through textFile().
-        """
-        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
-        try:
-            serializer.dump_stream(data, tempFile)
-            tempFile.close()
-            return reader_func(tempFile.name)
-        finally:
-            # readRDDFromFile eagerily reads the file so we can delete right after.
-            os.unlink(tempFile.name)
+    def _serialize_to_jvm(self, data, serializer, reader_func, createRDDServer):
+        """
+        Using py4j to send a large dataset to the jvm is really slow, so we use either a file
+        or a socket if we have encryption enabled.
+        :param data:
+        :param serializer:
+        :param reader_func:  A function which takes a filename and reads in the data in the jvm and
+                returns a JavaRDD. Only used when encryption is disabled.
+        :param createRDDServer:  A function which creates a PythonRDDServer in the jvm to
+               accept the serialized data, for use when encryption is enabled.
+        :return:
+        """
+        if self._encryption_enabled:
+            # with encryption, we open a server in java and send the data directly
+            server = createRDDServer()
+            (sock_file, _) = local_connect_and_auth(server.port(), server.secret())
+            chunked_out = ChunkedStream(sock_file, 8192)
+            serializer.dump_stream(data, chunked_out)
+            chunked_out.close()
+            # this call will block until the server has read all the data and processed it (or
+            # throws an exception)
+            r = server.getResult()
+            return r
+        else:
+            # without encryption, we serialize to a file, and we read the file in java and
+            # parallelize from there.
+            tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
+            try:
+                serializer.dump_stream(data, tempFile)
+                tempFile.close()
+                return reader_func(tempFile.name)
+            finally:
+                # the file is read eagerly, so we can delete it right after.
+                os.unlink(tempFile.name)
 
     def pickleFile(self, name, minPartitions=None):
         """

http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 4800677..ff9a612 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -731,6 +731,57 @@ def write_with_length(obj, stream):
     stream.write(obj)
 
 
+class ChunkedStream(object):
+
+    """
+    A file-like object that takes a stream of data, of unknown length, and breaks it into fixed
+    length frames.  The intended use case is serializing large data and sending it immediately over
+    a socket -- we do not want to buffer the entire data before sending it, but the receiving end
+    needs to know whether or not there is more data coming.
+
+    It works by buffering the incoming data in some fixed-size chunks.  If the buffer is full, it
+    first sends the buffer size, then the data.  This repeats as long as there is more data to send.
+    When this is closed, it sends the length of whatever data is in the buffer, then that data, and
+    finally a "length" of -1 to indicate the stream has completed.
+    """
+
+    def __init__(self, wrapped, buffer_size):
+        self.buffer_size = buffer_size
+        self.buffer = bytearray(buffer_size)
+        self.current_pos = 0
+        self.wrapped = wrapped
+
+    def write(self, bytes):
+        byte_pos = 0
+        byte_remaining = len(bytes)
+        while byte_remaining > 0:
+            new_pos = byte_remaining + self.current_pos
+            if new_pos < self.buffer_size:
+                # just put it in our buffer
+                self.buffer[self.current_pos:new_pos] = bytes[byte_pos:]
+                self.current_pos = new_pos
+                byte_remaining = 0
+            else:
+                # fill the buffer, send the length then the contents, and start filling again
+                space_left = self.buffer_size - self.current_pos
+                new_byte_pos = byte_pos + space_left
+                self.buffer[self.current_pos:self.buffer_size] = bytes[byte_pos:new_byte_pos]
+                write_int(self.buffer_size, self.wrapped)
+                self.wrapped.write(self.buffer)
+                byte_remaining -= space_left
+                byte_pos = new_byte_pos
+                self.current_pos = 0
+
+    def close(self):
+        # if there is anything left in the buffer, write it out first
+        if self.current_pos > 0:
+            write_int(self.current_pos, self.wrapped)
+            self.wrapped.write(self.buffer[:self.current_pos])
+        # -1 length indicates to the receiving end that we're done.
+        write_int(-1, self.wrapped)
+        self.wrapped.close()
+
+
 if __name__ == '__main__':
     import doctest
     (failure_count, test_count) = doctest.testmod()

http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/python/pyspark/sql/session.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 87d8d6a..51a38eb 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -539,12 +539,18 @@ class SparkSession(object):
                 struct.names[i] = name
             schema = struct
 
+        jsqlContext = self._wrapped._jsqlContext
+
         def reader_func(temp_filename):
-            return self._jvm.PythonSQLUtils.arrowReadStreamFromFile(
-                self._wrapped._jsqlContext, temp_filename, schema.json())
+            return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename)
+
+        def create_RDD_server():
+            return self._jvm.ArrowRDDServer(jsqlContext)
 
         # Create Spark DataFrame from Arrow stream file, using one batch per partition
-        jdf = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func)
+        jrdd = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func,
+                                          create_RDD_server)
+        jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext)
         df = DataFrame(jdf, self._wrapped)
         df._schema = schema
         return df

http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 8e5bc67..08d7cfa 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -26,6 +26,7 @@ import subprocess
 import pydoc
 import shutil
 import tempfile
+import threading
 import pickle
 import functools
 import time
@@ -228,12 +229,12 @@ class SQLTestUtils(object):
 class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils):
     @classmethod
     def setUpClass(cls):
-        ReusedPySparkTestCase.setUpClass()
+        super(ReusedSQLTestCase, cls).setUpClass()
         cls.spark = SparkSession(cls.sc)
 
     @classmethod
     def tearDownClass(cls):
-        ReusedPySparkTestCase.tearDownClass()
+        super(ReusedSQLTestCase, cls).tearDownClass()
         cls.spark.stop()
 
     def assertPandasEqual(self, expected, result):
@@ -4105,7 +4106,8 @@ class ArrowTests(ReusedSQLTestCase):
         from decimal import Decimal
         from distutils.version import LooseVersion
         import pyarrow as pa
-        ReusedSQLTestCase.setUpClass()
+        super(ArrowTests, cls).setUpClass()
+        cls.warnings_lock = threading.Lock()
 
         # Synchronize default timezone between Python and Java
         cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
@@ -4146,7 +4148,7 @@ class ArrowTests(ReusedSQLTestCase):
         if cls.tz_prev is not None:
             os.environ["TZ"] = cls.tz_prev
         time.tzset()
-        ReusedSQLTestCase.tearDownClass()
+        super(ArrowTests, cls).tearDownClass()
 
     def create_pandas_data_frame(self):
         import pandas as pd
@@ -4166,15 +4168,18 @@ class ArrowTests(ReusedSQLTestCase):
             schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
             df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
             with QuietTest(self.sc):
-                with warnings.catch_warnings(record=True) as warns:
-                    pdf = df.toPandas()
-                    # Catch and check the last UserWarning.
-                    user_warns = [
-                        warn.message for warn in warns if isinstance(warn.message, UserWarning)]
-                    self.assertTrue(len(user_warns) > 0)
-                    self.assertTrue(
-                        "Attempting non-optimization" in _exception_message(user_warns[-1]))
-                    self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
+                with self.warnings_lock:
+                    with warnings.catch_warnings(record=True) as warns:
+                        # we want the warnings to appear even if this test is run from a subclass
+                        warnings.simplefilter("always")
+                        pdf = df.toPandas()
+                        # Catch and check the last UserWarning.
+                        user_warns = [
+                            warn.message for warn in warns if isinstance(warn.message, UserWarning)]
+                        self.assertTrue(len(user_warns) > 0)
+                        self.assertTrue(
+                            "Attempting non-optimization" in _exception_message(user_warns[-1]))
+                        self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
 
     def test_toPandas_fallback_disabled(self):
         from distutils.version import LooseVersion
@@ -4183,8 +4188,9 @@ class ArrowTests(ReusedSQLTestCase):
         schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
         df = self.spark.createDataFrame([(None,)], schema=schema)
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(Exception, 'Unsupported type'):
-                df.toPandas()
+            with self.warnings_lock:
+                with self.assertRaisesRegexp(Exception, 'Unsupported type'):
+                    df.toPandas()
 
         # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0
         if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
@@ -4396,6 +4402,8 @@ class ArrowTests(ReusedSQLTestCase):
         with QuietTest(self.sc):
             with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}):
                 with warnings.catch_warnings(record=True) as warns:
+                    # we want the warnings to appear even if this test is run from a subclass
+                    warnings.simplefilter("always")
                     df = self.spark.createDataFrame(
                         pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
                     # Catch and check the last UserWarning.
@@ -4439,6 +4447,13 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertPandasEqual(pdf, df_from_pandas.toPandas())
 
 
+class EncryptionArrowTests(ArrowTests):
+
+    @classmethod
+    def conf(cls):
+        return super(EncryptionArrowTests, cls).conf().set("spark.io.encryption.enabled", "true")
+
+
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,
     _pandas_requirement_message or _pyarrow_requirement_message)

http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/python/pyspark/test_broadcast.py
----------------------------------------------------------------------
diff --git a/python/pyspark/test_broadcast.py b/python/pyspark/test_broadcast.py
new file mode 100644
index 0000000..a00329c
--- /dev/null
+++ b/python/pyspark/test_broadcast.py
@@ -0,0 +1,126 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import random
+import tempfile
+import unittest
+
+try:
+    import xmlrunner
+except ImportError:
+    xmlrunner = None
+
+from pyspark.broadcast import Broadcast
+from pyspark.conf import SparkConf
+from pyspark.context import SparkContext
+from pyspark.java_gateway import launch_gateway
+from pyspark.serializers import ChunkedStream
+
+
+class BroadcastTest(unittest.TestCase):
+
+    def tearDown(self):
+        if getattr(self, "sc", None) is not None:
+            self.sc.stop()
+            self.sc = None
+
+    def _test_encryption_helper(self, vs):
+        """
+        Creates a broadcast variable for each value in vs, and runs a simple job to make sure the
+        value is the same when it's read in the executors.  Also makes sure there are no task
+        failures.
+        """
+        bs = [self.sc.broadcast(value=v) for v in vs]
+        exec_values = self.sc.parallelize(range(2)).map(lambda x: [b.value for b in bs]).collect()
+        for ev in exec_values:
+            self.assertEqual(ev, vs)
+        # make sure there are no task failures
+        status = self.sc.statusTracker()
+        for jid in status.getJobIdsForGroup():
+            for sid in status.getJobInfo(jid).stageIds:
+                stage_info = status.getStageInfo(sid)
+                self.assertEqual(0, stage_info.numFailedTasks)
+
+    def _test_multiple_broadcasts(self, *extra_confs):
+        """
+        Test broadcast variables make it OK to the executors.  Tests multiple broadcast variables,
+        and also multiple jobs.
+        """
+        conf = SparkConf()
+        for key, value in extra_confs:
+            conf.set(key, value)
+        conf.setMaster("local-cluster[2,1,1024]")
+        self.sc = SparkContext(conf=conf)
+        self._test_encryption_helper([5])
+        self._test_encryption_helper([5, 10, 20])
+
+    def test_broadcast_with_encryption(self):
+        self._test_multiple_broadcasts(("spark.io.encryption.enabled", "true"))
+
+    def test_broadcast_no_encryption(self):
+        self._test_multiple_broadcasts()
+
+
+class BroadcastFrameProtocolTest(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        gateway = launch_gateway(SparkConf())
+        cls._jvm = gateway.jvm
+        cls.longMessage = True
+        random.seed(42)
+
+    def _test_chunked_stream(self, data, py_buf_size):
+        # write data using the chunked protocol from python.
+        chunked_file = tempfile.NamedTemporaryFile(delete=False)
+        dechunked_file = tempfile.NamedTemporaryFile(delete=False)
+        dechunked_file.close()
+        try:
+            out = ChunkedStream(chunked_file, py_buf_size)
+            out.write(data)
+            out.close()
+            # now try to read it in java
+            jin = self._jvm.java.io.FileInputStream(chunked_file.name)
+            jout = self._jvm.java.io.FileOutputStream(dechunked_file.name)
+            self._jvm.DechunkedInputStream.dechunkAndCopyToOutput(jin, jout)
+            # java should have decoded it back to the original data
+            self.assertEqual(len(data), os.stat(dechunked_file.name).st_size)
+            with open(dechunked_file.name, "rb") as f:
+                byte = f.read(1)
+                idx = 0
+                while byte:
+                    self.assertEqual(data[idx], bytearray(byte)[0], msg="idx = " + str(idx))
+                    byte = f.read(1)
+                    idx += 1
+        finally:
+            os.unlink(chunked_file.name)
+            os.unlink(dechunked_file.name)
+
+    def test_chunked_stream(self):
+        def random_bytes(n):
+            return bytearray(random.getrandbits(8) for _ in range(n))
+        for data_length in [1, 10, 100, 10000]:
+            for buffer_length in [1, 2, 5, 8192]:
+                self._test_chunked_stream(random_bytes(data_length), buffer_length)
+
+if __name__ == '__main__':
+    from pyspark.test_broadcast import *
+    if xmlrunner:
+        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+    else:
+        unittest.main(verbosity=2)

http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/python/pyspark/test_serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/test_serializers.py b/python/pyspark/test_serializers.py
new file mode 100644
index 0000000..5b43729
--- /dev/null
+++ b/python/pyspark/test_serializers.py
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import io
+import math
+import struct
+import sys
+import unittest
+
+try:
+    import xmlrunner
+except ImportError:
+    xmlrunner = None
+
+from pyspark import serializers
+
+
+def read_int(b):
+    return struct.unpack("!i", b)[0]
+
+
+def write_int(i):
+    return struct.pack("!i", i)
+
+
+class SerializersTest(unittest.TestCase):
+
+    def test_chunked_stream(self):
+        original_bytes = bytearray(range(100))
+        for data_length in [1, 10, 100]:
+            for buffer_length in [1, 2, 3, 5, 20, 99, 100, 101, 500]:
+                dest = ByteArrayOutput()
+                stream_out = serializers.ChunkedStream(dest, buffer_length)
+                stream_out.write(original_bytes[:data_length])
+                stream_out.close()
+                num_chunks = int(math.ceil(float(data_length) / buffer_length))
+                # length for each chunk, and a final -1 at the very end
+                exp_size = (num_chunks + 1) * 4 + data_length
+                self.assertEqual(len(dest.buffer), exp_size)
+                dest_pos = 0
+                data_pos = 0
+                for chunk_idx in range(num_chunks):
+                    chunk_length = read_int(dest.buffer[dest_pos:(dest_pos + 4)])
+                    if chunk_idx == num_chunks - 1:
+                        exp_length = data_length % buffer_length
+                        if exp_length == 0:
+                            exp_length = buffer_length
+                    else:
+                        exp_length = buffer_length
+                    self.assertEqual(chunk_length, exp_length)
+                    dest_pos += 4
+                    dest_chunk = dest.buffer[dest_pos:dest_pos + chunk_length]
+                    orig_chunk = original_bytes[data_pos:data_pos + chunk_length]
+                    self.assertEqual(dest_chunk, orig_chunk)
+                    dest_pos += chunk_length
+                    data_pos += chunk_length
+                # ends with a -1
+                self.assertEqual(dest.buffer[-4:], write_int(-1))
+
+
+class ByteArrayOutput(object):
+    def __init__(self):
+        self.buffer = bytearray()
+
+    def write(self, b):
+        self.buffer += b
+
+    def close(self):
+        pass
+
+if __name__ == '__main__':
+    from pyspark.test_serializers import *
+    if xmlrunner:
+        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+    else:
+        unittest.main(verbosity=2)

http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 8ac1df5..050c2dd 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -373,8 +373,15 @@ class PySparkTestCase(unittest.TestCase):
 class ReusedPySparkTestCase(unittest.TestCase):
 
     @classmethod
+    def conf(cls):
+        """
+        Override this in subclasses to supply a more specific conf
+        """
+        return SparkConf()
+
+    @classmethod
     def setUpClass(cls):
-        cls.sc = SparkContext('local[4]', cls.__name__)
+        cls.sc = SparkContext('local[4]', cls.__name__, conf=cls.conf())
 
     @classmethod
     def tearDownClass(cls):

http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index e934da4..974344f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -324,16 +324,34 @@ def main(infile, outfile):
             importlib.invalidate_caches()
 
         # fetch names and values of broadcast variables
+        needs_broadcast_decryption_server = read_bool(infile)
         num_broadcast_variables = read_int(infile)
+        if needs_broadcast_decryption_server:
+            # read the decrypted data from a server in the jvm
+            port = read_int(infile)
+            auth_secret = utf8_deserializer.loads(infile)
+            (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)
+
         for _ in range(num_broadcast_variables):
             bid = read_long(infile)
             if bid >= 0:
-                path = utf8_deserializer.loads(infile)
-                _broadcastRegistry[bid] = Broadcast(path=path)
+                if needs_broadcast_decryption_server:
+                    read_bid = read_long(broadcast_sock_file)
+                    assert(read_bid == bid)
+                    _broadcastRegistry[bid] = \
+                        Broadcast(sock_file=broadcast_sock_file)
+                else:
+                    path = utf8_deserializer.loads(infile)
+                    _broadcastRegistry[bid] = Broadcast(path=path)
+
             else:
                 bid = - bid - 1
                 _broadcastRegistry.pop(bid)
 
+        if needs_broadcast_decryption_server:
+            broadcast_sock_file.write(b'1')
+            broadcast_sock_file.close()
+
         _accumulatorRegistry.clear()
         eval_type = read_int(infile)
         if eval_type == PythonEvalType.NON_UDF:

http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index c0830e7..482e2bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -17,6 +17,12 @@
 
 package org.apache.spark.sql.api.python
 
+import java.io.InputStream
+import java.nio.channels.Channels
+
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.python.PythonRDDServer
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
@@ -33,19 +39,36 @@ private[sql] object PythonSQLUtils {
   }
 
   /**
-   * Python callable function to read a file in Arrow stream format and create a [[DataFrame]]
+   * Python callable function to read a file in Arrow stream format and create an [[RDD]]
    * using each serialized ArrowRecordBatch as a partition.
-   *
-   * @param sqlContext The active [[SQLContext]].
-   * @param filename File to read the Arrow stream from.
-   * @param schemaString JSON Formatted Spark schema for Arrow batches.
-   * @return A new [[DataFrame]].
    */
-  def arrowReadStreamFromFile(
-      sqlContext: SQLContext,
-      filename: String,
-      schemaString: String): DataFrame = {
-    val jrdd = ArrowConverters.readArrowStreamFromFile(sqlContext, filename)
-    ArrowConverters.toDataFrame(jrdd, schemaString, sqlContext)
+  def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): JavaRDD[Array[Byte]] = {
+    ArrowConverters.readArrowStreamFromFile(sqlContext, filename)
+  }
+
+  /**
+   * Python callable function to create a [[DataFrame]] from an RDD of serialized
+   * ArrowRecordBatches.
+   */
+  def toDataFrame(
+      arrowBatchRDD: JavaRDD[Array[Byte]],
+      schemaString: String,
+      sqlContext: SQLContext): DataFrame = {
+    ArrowConverters.toDataFrame(arrowBatchRDD, schemaString, sqlContext)
   }
 }
+
+/**
+ * Helper for making a dataframe from arrow data sent from python over a socket.  This is
+ * used when encryption is enabled, and we don't want to write data to a file.
+ */
+private[sql] class ArrowRDDServer(sqlContext: SQLContext) extends PythonRDDServer {
+
+  override protected def streamToRDD(input: InputStream): RDD[Array[Byte]] = {
+    // Create array to consume iterator so that we can safely close the inputStream
+    val batches = ArrowConverters.getBatchesFromStream(Channels.newChannel(input)).toArray
+    // Parallelize the record batches to create an RDD
+    JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length))
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 1a48bc8..2bf6a58 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.arrow
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, OutputStream}
-import java.nio.channels.{Channels, SeekableByteChannel}
+import java.nio.channels.{Channels, ReadableByteChannel}
 
 import scala.collection.JavaConverters._
 
@@ -31,6 +31,7 @@ import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer}
 import org.apache.spark.TaskContext
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
@@ -189,7 +190,7 @@ private[sql] object ArrowConverters {
   }
 
   /**
-   * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches.
+   * Create a DataFrame from an RDD of serialized ArrowRecordBatches.
    */
   private[sql] def toDataFrame(
       arrowBatchRDD: JavaRDD[Array[Byte]],
@@ -221,7 +222,7 @@ private[sql] object ArrowConverters {
   /**
    * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches.
    */
-  private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = {
+  private[sql] def getBatchesFromStream(in: ReadableByteChannel): Iterator[Array[Byte]] = {
 
     // Iterate over the serialized Arrow RecordBatch messages from a stream
     new Iterator[Array[Byte]] {
@@ -271,7 +272,7 @@ private[sql] object ArrowConverters {
         } else {
           if (bodyLength > 0) {
             // Skip message body if not a RecordBatch
-            in.position(in.position() + bodyLength)
+            Channels.newInputStream(in).skip(bodyLength)
           }
 
           // Proceed to next message


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org