Posted to commits@spark.apache.org by ir...@apache.org on 2018/09/17 19:40:35 UTC
[3/3] spark git commit: [PYSPARK] Updates to pyspark broadcast
[PYSPARK] Updates to pyspark broadcast
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/58419b92
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/58419b92
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/58419b92
Branch: refs/heads/master
Commit: 58419b92673c46911c25bc6c6b13397f880c6424
Parents: 553af22
Author: Imran Rashid <ir...@cloudera.com>
Authored: Mon Aug 13 21:35:34 2018 -0500
Committer: Imran Rashid <ir...@cloudera.com>
Committed: Mon Sep 17 14:06:09 2018 -0500
----------------------------------------------------------------------
.../org/apache/spark/api/python/PythonRDD.scala | 299 ++++++++++++++++---
.../apache/spark/api/python/PythonRunner.scala | 52 +++-
.../spark/api/python/PythonRDDSuite.scala | 23 +-
dev/sparktestsupport/modules.py | 2 +
python/pyspark/broadcast.py | 58 +++-
python/pyspark/context.py | 64 ++--
python/pyspark/serializers.py | 51 ++++
python/pyspark/sql/session.py | 12 +-
python/pyspark/sql/tests.py | 45 ++-
python/pyspark/test_broadcast.py | 126 ++++++++
python/pyspark/test_serializers.py | 90 ++++++
python/pyspark/tests.py | 9 +-
python/pyspark/worker.py | 22 +-
.../spark/sql/api/python/PythonSQLUtils.scala | 47 ++-
.../sql/execution/arrow/ArrowConverters.scala | 9 +-
15 files changed, 789 insertions(+), 120 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index e639a84..8b5a7a9 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -24,8 +24,10 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
import scala.collection.JavaConverters._
import scala.collection.mutable
+import scala.concurrent.Promise
+import scala.concurrent.duration.Duration
import scala.language.existentials
-import scala.util.control.NonFatal
+import scala.util.Try
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
@@ -37,6 +39,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.input.PortableDataStream
import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.util._
@@ -169,27 +172,34 @@ private[spark] object PythonRDD extends Logging {
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
- val file = new DataInputStream(new FileInputStream(filename))
+ readRDDFromInputStream(sc.sc, new FileInputStream(filename), parallelism)
+ }
+
+ def readRDDFromInputStream(
+ sc: SparkContext,
+ in: InputStream,
+ parallelism: Int): JavaRDD[Array[Byte]] = {
+ val din = new DataInputStream(in)
try {
val objs = new mutable.ArrayBuffer[Array[Byte]]
try {
while (true) {
- val length = file.readInt()
+ val length = din.readInt()
val obj = new Array[Byte](length)
- file.readFully(obj)
+ din.readFully(obj)
objs += obj
}
} catch {
case eof: EOFException => // No-op
}
- JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
+ JavaRDD.fromRDD(sc.parallelize(objs, parallelism))
} finally {
- file.close()
+ din.close()
}
}
- def readBroadcastFromFile(sc: JavaSparkContext, path: String): Broadcast[PythonBroadcast] = {
- sc.broadcast(new PythonBroadcast(path))
+ def setupBroadcast(path: String): PythonBroadcast = {
+ new PythonBroadcast(path)
}
def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
@@ -419,34 +429,15 @@ private[spark] object PythonRDD extends Logging {
*/
private[spark] def serveToStream(
threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
- val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
- // Close the socket if no connection in 15 seconds
- serverSocket.setSoTimeout(15000)
-
- new Thread(threadName) {
- setDaemon(true)
- override def run() {
- try {
- val sock = serverSocket.accept()
- authHelper.authClient(sock)
-
- val out = new BufferedOutputStream(sock.getOutputStream)
- Utils.tryWithSafeFinally {
- writeFunc(out)
- } {
- out.close()
- sock.close()
- }
- } catch {
- case NonFatal(e) =>
- logError(s"Error while sending iterator", e)
- } finally {
- serverSocket.close()
- }
+ val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { s =>
+ val out = new BufferedOutputStream(s.getOutputStream())
+ Utils.tryWithSafeFinally {
+ writeFunc(out)
+ } {
+ out.close()
}
- }.start()
-
- Array(serverSocket.getLocalPort, authHelper.secret)
+ }
+ Array(port, secret)
}
private def getMergedConf(confAsMap: java.util.HashMap[String, String],
@@ -664,13 +655,11 @@ private[spark] class PythonAccumulatorV2(
}
}
-/**
- * A Wrapper for Python Broadcast, which is written into disk by Python. It also will
- * write the data into disk after deserialization, then Python can read it from disks.
- */
// scalastyle:off no.finalize
private[spark] class PythonBroadcast(@transient var path: String) extends Serializable
- with Logging {
+ with Logging {
+
+ private var encryptionServer: PythonServer[Unit] = null
/**
* Read data from disks, then copy it to `out`
@@ -713,5 +702,235 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
}
super.finalize()
}
+
+ def setupEncryptionServer(): Array[Any] = {
+ encryptionServer = new PythonServer[Unit]("broadcast-encrypt-server") {
+ override def handleConnection(sock: Socket): Unit = {
+ val env = SparkEnv.get
+ val in = sock.getInputStream()
+ val dir = new File(Utils.getLocalDir(env.conf))
+ val file = File.createTempFile("broadcast", "", dir)
+ path = file.getAbsolutePath
+ val out = env.serializerManager.wrapForEncryption(new FileOutputStream(path))
+ DechunkedInputStream.dechunkAndCopyToOutput(in, out)
+ }
+ }
+ Array(encryptionServer.port, encryptionServer.secret)
+ }
+
+ def waitTillDataReceived(): Unit = encryptionServer.getResult()
}
// scalastyle:on no.finalize
+
+/**
+ * The inverse of pyspark's ChunkedStream for sending data of unknown size.
+ *
+ * We might be serializing a really large object from python -- we don't want
+ * python to buffer the whole thing in memory, nor can it write to a file,
+ * so we don't know the length in advance. So python writes it in chunks, each chunk
+ * preceded by a length, until we get a "length" of -1 which serves as EOF.
+ *
+ * Tested from python tests.
+ */
+private[spark] class DechunkedInputStream(wrapped: InputStream) extends InputStream with Logging {
+ private val din = new DataInputStream(wrapped)
+ private var remainingInChunk = din.readInt()
+
+ override def read(): Int = {
+ val into = new Array[Byte](1)
+ val n = read(into, 0, 1)
+ if (n == -1) {
+ -1
+ } else {
+ // if you just cast a byte to an int, then anything > 127 is negative, which is interpreted
+ // as an EOF
+ val b = into(0)
+ if (b < 0) {
+ 256 + b
+ } else {
+ b
+ }
+ }
+ }
+
+ override def read(dest: Array[Byte], off: Int, len: Int): Int = {
+ if (remainingInChunk == -1) {
+ return -1
+ }
+ var destSpace = len
+ var destPos = off
+ while (destSpace > 0 && remainingInChunk != -1) {
+ val toCopy = math.min(remainingInChunk, destSpace)
+ val read = din.read(dest, destPos, toCopy)
+ destPos += read
+ destSpace -= read
+ remainingInChunk -= read
+ if (remainingInChunk == 0) {
+ remainingInChunk = din.readInt()
+ }
+ }
+ assert(destSpace == 0 || remainingInChunk == -1)
+ return destPos - off
+ }
+
+ override def close(): Unit = wrapped.close()
+}
+
+private[spark] object DechunkedInputStream {
+
+ /**
+ * Dechunks the input, copies to output, and closes both input and the output safely.
+ */
+ def dechunkAndCopyToOutput(chunked: InputStream, out: OutputStream): Unit = {
+ val dechunked = new DechunkedInputStream(chunked)
+ Utils.tryWithSafeFinally {
+ Utils.copyStream(dechunked, out)
+ } {
+ JavaUtils.closeQuietly(out)
+ JavaUtils.closeQuietly(dechunked)
+ }
+ }
+}
+
+/**
+ * Creates a server in the jvm to communicate with python for handling one batch of data, with
+ * authentication and error handling.
+ */
+private[spark] abstract class PythonServer[T](
+ authHelper: SocketAuthHelper,
+ threadName: String) {
+
+ def this(env: SparkEnv, threadName: String) = this(new SocketAuthHelper(env.conf), threadName)
+ def this(threadName: String) = this(SparkEnv.get, threadName)
+
+ val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { sock =>
+ promise.complete(Try(handleConnection(sock)))
+ }
+
+ /**
+ * Handle a connection which has already been authenticated. Any error from this function
+ * will clean up this connection and the entire server, and get propagated to [[getResult]].
+ */
+ def handleConnection(sock: Socket): T
+
+ val promise = Promise[T]()
+
+ /**
+ * Blocks indefinitely for [[handleConnection]] to finish, and returns that result. If
+ * handleConnection throws an exception, this will throw an exception which includes the original
+ * exception as a cause.
+ */
+ def getResult(): T = {
+ getResult(Duration.Inf)
+ }
+
+ def getResult(wait: Duration): T = {
+ ThreadUtils.awaitResult(promise.future, wait)
+ }
+
+}
+
+private[spark] object PythonServer {
+
+ /**
+ * Create a socket server and run the user function on the socket in a background thread.
+ *
+ * The socket server accepts only one connection, and closes if no connection is made
+ * within 15 seconds.
+ *
+ * The thread will terminate after the supplied user function returns, or if there are any
+ * exceptions.
+ *
+ * If you need the result of the supplied function, create a subclass of [[PythonServer]].
+ *
+ * @return The port number of a local socket and the secret for authentication.
+ */
+ def setupOneConnectionServer(
+ authHelper: SocketAuthHelper,
+ threadName: String)
+ (func: Socket => Unit): (Int, String) = {
+ val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
+ // Close the socket if no connection in 15 seconds
+ serverSocket.setSoTimeout(15000)
+
+ new Thread(threadName) {
+ setDaemon(true)
+ override def run(): Unit = {
+ var sock: Socket = null
+ try {
+ sock = serverSocket.accept()
+ authHelper.authClient(sock)
+ func(sock)
+ } finally {
+ JavaUtils.closeQuietly(serverSocket)
+ JavaUtils.closeQuietly(sock)
+ }
+ }
+ }.start()
+ (serverSocket.getLocalPort, authHelper.secret)
+ }
+}
+
+/**
+ * Sends decrypted broadcast data to python worker. See [[PythonRunner]] for entire protocol.
+ */
+private[spark] class EncryptedPythonBroadcastServer(
+ val env: SparkEnv,
+ val idsAndFiles: Seq[(Long, String)])
+ extends PythonServer[Unit]("broadcast-decrypt-server") with Logging {
+
+ override def handleConnection(socket: Socket): Unit = {
+ val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream()))
+ var socketIn: InputStream = null
+ // send the broadcast id, then the decrypted data. We don't need to send the length; the
+ // python pickle module just needs a stream.
+ Utils.tryWithSafeFinally {
+ (idsAndFiles).foreach { case (id, path) =>
+ out.writeLong(id)
+ val in = env.serializerManager.wrapForEncryption(new FileInputStream(path))
+ Utils.tryWithSafeFinally {
+ Utils.copyStream(in, out, false)
+ } {
+ in.close()
+ }
+ }
+ logTrace("waiting for python to accept broadcast data over socket")
+ out.flush()
+ socketIn = socket.getInputStream()
+ socketIn.read()
+ logTrace("done serving broadcast data")
+ } {
+ JavaUtils.closeQuietly(socketIn)
+ JavaUtils.closeQuietly(out)
+ }
+ }
+
+ def waitTillBroadcastDataSent(): Unit = {
+ getResult()
+ }
+}
+
+/**
+ * Helper for making an RDD[Array[Byte]] from some python data, by reading the data from python
+ * over a socket. This is used in preference to writing data to a file when encryption is enabled.
+ */
+private[spark] abstract class PythonRDDServer
+ extends PythonServer[JavaRDD[Array[Byte]]]("pyspark-parallelize-server") {
+
+ def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
+ val in = sock.getInputStream()
+ val dechunkedInput: InputStream = new DechunkedInputStream(in)
+ streamToRDD(dechunkedInput)
+ }
+
+ protected def streamToRDD(input: InputStream): RDD[Array[Byte]]
+
+}
+
+private[spark] class PythonParallelizeServer(sc: SparkContext, parallelism: Int)
+ extends PythonRDDServer {
+
+ override protected def streamToRDD(input: InputStream): RDD[Array[Byte]] = {
+ PythonRDD.readRDDFromInputStream(sc, input, parallelism)
+ }
+}
+
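As a point of reference for the framing DechunkedInputStream consumes, here is a rough Python sketch (illustrative only, not part of the patch; the helper name read_chunks is made up): every chunk is a big-endian 4-byte length followed by that many payload bytes, and a length of -1 marks end-of-stream, matching what ChunkedStream writes on close.

import struct

def read_chunks(stream):
    # Hypothetical reader mirroring DechunkedInputStream: yield each chunk's payload
    # until the -1 sentinel. Assumes stream.read(n) returns exactly n bytes (true for
    # a file; a socket would need a read-fully loop, as the Scala code above does).
    while True:
        length = struct.unpack("!i", stream.read(4))[0]
        if length == -1:
            return
        yield stream.read(length)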
http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 4c53bc2..6e53a04 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -289,19 +289,51 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
val newBids = broadcastVars.map(_.id).toSet
// number of different broadcasts
val toRemove = oldBids.diff(newBids)
- val cnt = toRemove.size + newBids.diff(oldBids).size
+ val addedBids = newBids.diff(oldBids)
+ val cnt = toRemove.size + addedBids.size
+ val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty
+ dataOut.writeBoolean(needsDecryptionServer)
dataOut.writeInt(cnt)
- for (bid <- toRemove) {
- // remove the broadcast from worker
- dataOut.writeLong(- bid - 1) // bid >= 0
- oldBids.remove(bid)
+ def sendBidsToRemove(): Unit = {
+ for (bid <- toRemove) {
+ // remove the broadcast from worker
+ dataOut.writeLong(-bid - 1) // bid >= 0
+ oldBids.remove(bid)
+ }
}
- for (broadcast <- broadcastVars) {
- if (!oldBids.contains(broadcast.id)) {
+ if (needsDecryptionServer) {
+ // if there is encryption, we set up a server which reads the encrypted files, and sends
+ // the decrypted data to python
+ val idsAndFiles = broadcastVars.flatMap { broadcast =>
+ if (!oldBids.contains(broadcast.id)) {
+ Some((broadcast.id, broadcast.value.path))
+ } else {
+ None
+ }
+ }
+ val server = new EncryptedPythonBroadcastServer(env, idsAndFiles)
+ dataOut.writeInt(server.port)
+ logTrace(s"broadcast decryption server setup on ${server.port}")
+ PythonRDD.writeUTF(server.secret, dataOut)
+ sendBidsToRemove()
+ idsAndFiles.foreach { case (id, _) =>
// send new broadcast
- dataOut.writeLong(broadcast.id)
- PythonRDD.writeUTF(broadcast.value.path, dataOut)
- oldBids.add(broadcast.id)
+ dataOut.writeLong(id)
+ oldBids.add(id)
+ }
+ dataOut.flush()
+ logTrace("waiting for python to read decrypted broadcast data from server")
+ server.waitTillBroadcastDataSent()
+ logTrace("done sending decrypted data to python")
+ } else {
+ sendBidsToRemove()
+ for (broadcast <- broadcastVars) {
+ if (!oldBids.contains(broadcast.id)) {
+ // send new broadcast
+ dataOut.writeLong(broadcast.id)
+ PythonRDD.writeUTF(broadcast.value.path, dataOut)
+ oldBids.add(broadcast.id)
+ }
}
}
dataOut.flush()
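For reference, the writes above pair with the following read order on the python worker (a simplified sketch based on this patch, not the actual worker code -- see the python/pyspark/worker.py hunk below for the real reader; read_bool, read_int, read_long and UTF8Deserializer are the existing pyspark.serializers helpers):

from pyspark.serializers import read_bool, read_int, read_long, UTF8Deserializer

utf8_deserializer = UTF8Deserializer()

def read_broadcast_header(infile):
    needs_server = read_bool(infile)   # true only with encryption and newly added broadcasts
    count = read_int(infile)           # number of removed + added broadcast ids
    if needs_server:
        port = read_int(infile)        # EncryptedPythonBroadcastServer port
        secret = utf8_deserializer.loads(infile)
    for _ in range(count):
        bid = read_long(infile)        # bid < 0 means remove broadcast (-bid - 1)
        if bid >= 0 and not needs_server:
            # without encryption, each new broadcast id is followed by its pickle file path
            path = utf8_deserializer.loads(infile)
        # with encryption, the broadcast data itself is fetched from the decryption server socket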
http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
index 05b4e67..6f9b583 100644
--- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -18,9 +18,13 @@
package org.apache.spark.api.python
import java.io.{ByteArrayOutputStream, DataOutputStream}
+import java.net.{InetAddress, Socket}
import java.nio.charset.StandardCharsets
-import org.apache.spark.SparkFunSuite
+import scala.concurrent.duration.Duration
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.security.SocketAuthHelper
class PythonRDDSuite extends SparkFunSuite {
@@ -44,4 +48,21 @@ class PythonRDDSuite extends SparkFunSuite {
("a".getBytes(StandardCharsets.UTF_8), null),
(null, "b".getBytes(StandardCharsets.UTF_8))), buffer)
}
+
+ test("python server error handling") {
+ val authHelper = new SocketAuthHelper(new SparkConf())
+ val errorServer = new ExceptionPythonServer(authHelper)
+ val client = new Socket(InetAddress.getLoopbackAddress(), errorServer.port)
+ authHelper.authToServer(client)
+ val ex = intercept[Exception] { errorServer.getResult(Duration(1, "second")) }
+ assert(ex.getCause().getMessage().contains("exception within handleConnection"))
+ }
+
+ class ExceptionPythonServer(authHelper: SocketAuthHelper)
+ extends PythonServer[Unit](authHelper, "error-server") {
+
+ override def handleConnection(sock: Socket): Unit = {
+ throw new Exception("exception within handleConnection")
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/dev/sparktestsupport/modules.py
----------------------------------------------------------------------
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 2aa3555..e267fbf 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -387,6 +387,8 @@ pyspark_core = Module(
"pyspark.profiler",
"pyspark.shuffle",
"pyspark.tests",
+ "pyspark.test_broadcast",
+ "pyspark.test_serializers",
"pyspark.util",
]
)
http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/python/pyspark/broadcast.py
----------------------------------------------------------------------
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index b3dfc99..1c7f2a7 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -15,13 +15,16 @@
# limitations under the License.
#
+import gc
import os
+import socket
import sys
-import gc
from tempfile import NamedTemporaryFile
import threading
from pyspark.cloudpickle import print_exec
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.serializers import ChunkedStream
from pyspark.util import _exception_message
if sys.version < '3':
@@ -64,19 +67,43 @@ class Broadcast(object):
>>> large_broadcast = sc.broadcast(range(10000))
"""
- def __init__(self, sc=None, value=None, pickle_registry=None, path=None):
+ def __init__(self, sc=None, value=None, pickle_registry=None, path=None,
+ sock_file=None):
"""
Should not be called directly by users -- use L{SparkContext.broadcast()}
instead.
"""
if sc is not None:
+ # we're on the driver. We want the pickled data to end up in a file (maybe encrypted)
f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
- self._path = self.dump(value, f)
- self._jbroadcast = sc._jvm.PythonRDD.readBroadcastFromFile(sc._jsc, self._path)
+ self._path = f.name
+ python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path)
+ if sc._encryption_enabled:
+ # with encryption, we ask the jvm to do the encryption for us; we send it data
+ # over a socket
+ port, auth_secret = python_broadcast.setupEncryptionServer()
+ (encryption_sock_file, _) = local_connect_and_auth(port, auth_secret)
+ broadcast_out = ChunkedStream(encryption_sock_file, 8192)
+ else:
+ # no encryption, we can just write pickled data directly to the file from python
+ broadcast_out = f
+ self.dump(value, broadcast_out)
+ if sc._encryption_enabled:
+ python_broadcast.waitTillDataReceived()
+ self._jbroadcast = sc._jsc.broadcast(python_broadcast)
self._pickle_registry = pickle_registry
else:
+ # we're on an executor
self._jbroadcast = None
- self._path = path
+ if sock_file is not None:
+ # the jvm is doing decryption for us. Read the value
+ # immediately from the sock_file
+ self._value = self.load(sock_file)
+ else:
+ # the jvm just dumps the pickled data in path -- we'll unpickle lazily when
+ # the value is requested
+ assert(path is not None)
+ self._path = path
def dump(self, value, f):
try:
@@ -89,24 +116,25 @@ class Broadcast(object):
print_exec(sys.stderr)
raise pickle.PicklingError(msg)
f.close()
- return f.name
- def load(self, path):
+ def load_from_path(self, path):
with open(path, 'rb', 1 << 20) as f:
- # pickle.load() may create lots of objects, disable GC
- # temporary for better performance
- gc.disable()
- try:
- return pickle.load(f)
- finally:
- gc.enable()
+ return self.load(f)
+
+ def load(self, file):
+ # "file" could also be a socket
+ gc.disable()
+ try:
+ return pickle.load(file)
+ finally:
+ gc.enable()
@property
def value(self):
""" Return the broadcasted value
"""
if not hasattr(self, "_value") and self._path is not None:
- self._value = self.load(self._path)
+ self._value = self.load_from_path(self._path)
return self._value
def unpersist(self, blocking=False):
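A minimal end-to-end sketch of the encrypted broadcast path added above (assumes a build that includes this patch; master, app name and values are arbitrary): with spark.io.encryption.enabled set, the driver streams the pickled value to the jvm's encryption server over a socket instead of leaving a plaintext temp file behind.

from pyspark import SparkConf, SparkContext

conf = SparkConf().set("spark.io.encryption.enabled", "true")
sc = SparkContext(master="local[2]", appName="broadcast-encryption-demo", conf=conf)
b = sc.broadcast({"answer": 42})      # pickled data is sent to the jvm encryption server
vals = sc.parallelize(range(2)).map(lambda _: b.value["answer"]).collect()
assert vals == [42, 42]
sc.stop()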
http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 4cabae4..2c92c29 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -33,9 +33,9 @@ from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast, BroadcastPickleRegistry
from pyspark.conf import SparkConf
from pyspark.files import SparkFiles
-from pyspark.java_gateway import launch_gateway
+from pyspark.java_gateway import launch_gateway, local_connect_and_auth
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
- PairDeserializer, AutoBatchedSerializer, NoOpSerializer
+ PairDeserializer, AutoBatchedSerializer, NoOpSerializer, ChunkedStream
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
from pyspark.traceback_utils import CallSite, first_spark_call
@@ -189,6 +189,13 @@ class SparkContext(object):
self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token)
self._jsc.sc().register(self._javaAccumulator)
+ # If encryption is enabled, we need to set up a server in the jvm to read broadcast
+ # data via a socket.
+ # scala's mangled names w/ $ in them require special treatment.
+ encryption_conf = self._jvm.org.apache.spark.internal.config.__getattr__("package$")\
+ .__getattr__("MODULE$").IO_ENCRYPTION_ENABLED()
+ self._encryption_enabled = self._jsc.sc().conf().get(encryption_conf)
+
self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
self.pythonVer = "%d.%d" % sys.version_info[:2]
@@ -498,23 +505,46 @@ class SparkContext(object):
def reader_func(temp_filename):
return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices)
- jrdd = self._serialize_to_jvm(c, serializer, reader_func)
+ def createRDDServer():
+ return self._jvm.PythonParallelizeServer(self._jsc.sc(), numSlices)
+
+ jrdd = self._serialize_to_jvm(c, serializer, reader_func, createRDDServer)
return RDD(jrdd, self, serializer)
- def _serialize_to_jvm(self, data, serializer, reader_func):
- """
- Calling the Java parallelize() method with an ArrayList is too slow,
- because it sends O(n) Py4J commands. As an alternative, serialized
- objects are written to a file and loaded through textFile().
- """
- tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
- try:
- serializer.dump_stream(data, tempFile)
- tempFile.close()
- return reader_func(tempFile.name)
- finally:
- # readRDDFromFile eagerily reads the file so we can delete right after.
- os.unlink(tempFile.name)
+ def _serialize_to_jvm(self, data, serializer, reader_func, createRDDServer):
+ """
+ Using py4j to send a large dataset to the jvm is really slow, so we use either a file or,
+ if encryption is enabled, a socket.
+ :param data: the data to serialize and send to the jvm
+ :param serializer: the serializer to use to write the data
+ :param reader_func: A function which takes a filename and reads in the data in the jvm and
+ returns a JavaRDD. Only used when encryption is disabled.
+ :param createRDDServer: A function which creates a PythonRDDServer in the jvm to
+ accept the serialized data, for use when encryption is enabled.
+ :return: a JavaRDD of the serialized data
+ """
+ if self._encryption_enabled:
+ # with encryption, we open a server in java and send the data directly
+ server = createRDDServer()
+ (sock_file, _) = local_connect_and_auth(server.port(), server.secret())
+ chunked_out = ChunkedStream(sock_file, 8192)
+ serializer.dump_stream(data, chunked_out)
+ chunked_out.close()
+ # this call will block until the server has read all the data and processed it (or
+ # throws an exception)
+ r = server.getResult()
+ return r
+ else:
+ # without encryption, we serialize to a file, and we read the file in java and
+ # parallelize from there.
+ tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
+ try:
+ serializer.dump_stream(data, tempFile)
+ tempFile.close()
+ return reader_func(tempFile.name)
+ finally:
+ # we eagerly read the file so we can delete it right after.
+ os.unlink(tempFile.name)
def pickleFile(self, name, minPartitions=None):
"""
http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 4800677..ff9a612 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -731,6 +731,57 @@ def write_with_length(obj, stream):
stream.write(obj)
+class ChunkedStream(object):
+
+ """
+ This is a file-like object that takes a stream of data of unknown length and breaks it into
+ fixed-length frames. The intended use case is serializing large data and sending it immediately over
+ a socket -- we do not want to buffer the entire data before sending it, but the receiving end
+ needs to know whether or not there is more data coming.
+
+ It works by buffering the incoming data in a fixed-size buffer. Once the buffer fills, it
+ first sends the buffer size, then the data. This repeats as long as there is more data to send.
+ When this is closed, it sends the length of whatever data is in the buffer, then that data, and
+ finally a "length" of -1 to indicate the stream has completed.
+ """
+
+ def __init__(self, wrapped, buffer_size):
+ self.buffer_size = buffer_size
+ self.buffer = bytearray(buffer_size)
+ self.current_pos = 0
+ self.wrapped = wrapped
+
+ def write(self, bytes):
+ byte_pos = 0
+ byte_remaining = len(bytes)
+ while byte_remaining > 0:
+ new_pos = byte_remaining + self.current_pos
+ if new_pos < self.buffer_size:
+ # just put it in our buffer
+ self.buffer[self.current_pos:new_pos] = bytes[byte_pos:]
+ self.current_pos = new_pos
+ byte_remaining = 0
+ else:
+ # fill the buffer, send the length then the contents, and start filling again
+ space_left = self.buffer_size - self.current_pos
+ new_byte_pos = byte_pos + space_left
+ self.buffer[self.current_pos:self.buffer_size] = bytes[byte_pos:new_byte_pos]
+ write_int(self.buffer_size, self.wrapped)
+ self.wrapped.write(self.buffer)
+ byte_remaining -= space_left
+ byte_pos = new_byte_pos
+ self.current_pos = 0
+
+ def close(self):
+ # if there is anything left in the buffer, write it out first
+ if self.current_pos > 0:
+ write_int(self.current_pos, self.wrapped)
+ self.wrapped.write(self.buffer[:self.current_pos])
+ # -1 length indicates to the receiving end that we're done.
+ write_int(-1, self.wrapped)
+ self.wrapped.close()
+
+
if __name__ == '__main__':
import doctest
(failure_count, test_count) = doctest.testmod()
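To make the framing concrete (a throwaway sketch; ByteSink is a hypothetical in-memory sink like the helper in test_serializers.py below): pushing 10 bytes through a ChunkedStream with a 4-byte buffer produces [4][4 bytes][4][4 bytes][2][2 bytes][-1], 26 bytes in total.

import struct
from pyspark.serializers import ChunkedStream

class ByteSink(object):
    # in-memory sink whose close() is a no-op, so the written bytes stay inspectable
    def __init__(self):
        self.buffer = bytearray()
    def write(self, b):
        self.buffer += b
    def close(self):
        pass

sink = ByteSink()
out = ChunkedStream(sink, 4)          # 4-byte buffer => at most 4 payload bytes per chunk
out.write(bytearray(range(10)))
out.close()
assert len(sink.buffer) == 26                                   # 3 length headers + 10 payload bytes + sentinel
assert struct.unpack("!i", bytes(sink.buffer[0:4]))[0] == 4     # first frame length
assert struct.unpack("!i", bytes(sink.buffer[-4:]))[0] == -1    # end-of-stream marker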
http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/python/pyspark/sql/session.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 87d8d6a..51a38eb 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -539,12 +539,18 @@ class SparkSession(object):
struct.names[i] = name
schema = struct
+ jsqlContext = self._wrapped._jsqlContext
+
def reader_func(temp_filename):
- return self._jvm.PythonSQLUtils.arrowReadStreamFromFile(
- self._wrapped._jsqlContext, temp_filename, schema.json())
+ return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename)
+
+ def create_RDD_server():
+ return self._jvm.ArrowRDDServer(jsqlContext)
# Create Spark DataFrame from Arrow stream file, using one batch per partition
- jdf = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func)
+ jrdd = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func,
+ create_RDD_server)
+ jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext)
df = DataFrame(jdf, self._wrapped)
df._schema = schema
return df
http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 8e5bc67..08d7cfa 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -26,6 +26,7 @@ import subprocess
import pydoc
import shutil
import tempfile
+import threading
import pickle
import functools
import time
@@ -228,12 +229,12 @@ class SQLTestUtils(object):
class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils):
@classmethod
def setUpClass(cls):
- ReusedPySparkTestCase.setUpClass()
+ super(ReusedSQLTestCase, cls).setUpClass()
cls.spark = SparkSession(cls.sc)
@classmethod
def tearDownClass(cls):
- ReusedPySparkTestCase.tearDownClass()
+ super(ReusedSQLTestCase, cls).tearDownClass()
cls.spark.stop()
def assertPandasEqual(self, expected, result):
@@ -4105,7 +4106,8 @@ class ArrowTests(ReusedSQLTestCase):
from decimal import Decimal
from distutils.version import LooseVersion
import pyarrow as pa
- ReusedSQLTestCase.setUpClass()
+ super(ArrowTests, cls).setUpClass()
+ cls.warnings_lock = threading.Lock()
# Synchronize default timezone between Python and Java
cls.tz_prev = os.environ.get("TZ", None) # save current tz if set
@@ -4146,7 +4148,7 @@ class ArrowTests(ReusedSQLTestCase):
if cls.tz_prev is not None:
os.environ["TZ"] = cls.tz_prev
time.tzset()
- ReusedSQLTestCase.tearDownClass()
+ super(ArrowTests, cls).tearDownClass()
def create_pandas_data_frame(self):
import pandas as pd
@@ -4166,15 +4168,18 @@ class ArrowTests(ReusedSQLTestCase):
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
with QuietTest(self.sc):
- with warnings.catch_warnings(record=True) as warns:
- pdf = df.toPandas()
- # Catch and check the last UserWarning.
- user_warns = [
- warn.message for warn in warns if isinstance(warn.message, UserWarning)]
- self.assertTrue(len(user_warns) > 0)
- self.assertTrue(
- "Attempting non-optimization" in _exception_message(user_warns[-1]))
- self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
+ with self.warnings_lock:
+ with warnings.catch_warnings(record=True) as warns:
+ # we want the warnings to appear even if this test is run from a subclass
+ warnings.simplefilter("always")
+ pdf = df.toPandas()
+ # Catch and check the last UserWarning.
+ user_warns = [
+ warn.message for warn in warns if isinstance(warn.message, UserWarning)]
+ self.assertTrue(len(user_warns) > 0)
+ self.assertTrue(
+ "Attempting non-optimization" in _exception_message(user_warns[-1]))
+ self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
def test_toPandas_fallback_disabled(self):
from distutils.version import LooseVersion
@@ -4183,8 +4188,9 @@ class ArrowTests(ReusedSQLTestCase):
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
with QuietTest(self.sc):
- with self.assertRaisesRegexp(Exception, 'Unsupported type'):
- df.toPandas()
+ with self.warnings_lock:
+ with self.assertRaisesRegexp(Exception, 'Unsupported type'):
+ df.toPandas()
# TODO: remove BinaryType check once minimum pyarrow version is 0.10.0
if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
@@ -4396,6 +4402,8 @@ class ArrowTests(ReusedSQLTestCase):
with QuietTest(self.sc):
with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}):
with warnings.catch_warnings(record=True) as warns:
+ # we want the warnings to appear even if this test is run from a subclass
+ warnings.simplefilter("always")
df = self.spark.createDataFrame(
pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
# Catch and check the last UserWarning.
@@ -4439,6 +4447,13 @@ class ArrowTests(ReusedSQLTestCase):
self.assertPandasEqual(pdf, df_from_pandas.toPandas())
+class EncryptionArrowTests(ArrowTests):
+
+ @classmethod
+ def conf(cls):
+ return super(EncryptionArrowTests, cls).conf().set("spark.io.encryption.enabled", "true")
+
+
@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
_pandas_requirement_message or _pyarrow_requirement_message)
http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/python/pyspark/test_broadcast.py
----------------------------------------------------------------------
diff --git a/python/pyspark/test_broadcast.py b/python/pyspark/test_broadcast.py
new file mode 100644
index 0000000..a00329c
--- /dev/null
+++ b/python/pyspark/test_broadcast.py
@@ -0,0 +1,126 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import random
+import tempfile
+import unittest
+
+try:
+ import xmlrunner
+except ImportError:
+ xmlrunner = None
+
+from pyspark.broadcast import Broadcast
+from pyspark.conf import SparkConf
+from pyspark.context import SparkContext
+from pyspark.java_gateway import launch_gateway
+from pyspark.serializers import ChunkedStream
+
+
+class BroadcastTest(unittest.TestCase):
+
+ def tearDown(self):
+ if getattr(self, "sc", None) is not None:
+ self.sc.stop()
+ self.sc = None
+
+ def _test_encryption_helper(self, vs):
+ """
+ Creates a broadcast variable for each value in vs, and runs a simple job to make sure the
+ value is the same when it's read in the executors. Also makes sure there are no task
+ failures.
+ """
+ bs = [self.sc.broadcast(value=v) for v in vs]
+ exec_values = self.sc.parallelize(range(2)).map(lambda x: [b.value for b in bs]).collect()
+ for ev in exec_values:
+ self.assertEqual(ev, vs)
+ # make sure there are no task failures
+ status = self.sc.statusTracker()
+ for jid in status.getJobIdsForGroup():
+ for sid in status.getJobInfo(jid).stageIds:
+ stage_info = status.getStageInfo(sid)
+ self.assertEqual(0, stage_info.numFailedTasks)
+
+ def _test_multiple_broadcasts(self, *extra_confs):
+ """
+ Test that broadcast variables make it to the executors correctly. Tests multiple broadcast
+ variables, and also multiple jobs.
+ """
+ conf = SparkConf()
+ for key, value in extra_confs:
+ conf.set(key, value)
+ conf.setMaster("local-cluster[2,1,1024]")
+ self.sc = SparkContext(conf=conf)
+ self._test_encryption_helper([5])
+ self._test_encryption_helper([5, 10, 20])
+
+ def test_broadcast_with_encryption(self):
+ self._test_multiple_broadcasts(("spark.io.encryption.enabled", "true"))
+
+ def test_broadcast_no_encryption(self):
+ self._test_multiple_broadcasts()
+
+
+class BroadcastFrameProtocolTest(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ gateway = launch_gateway(SparkConf())
+ cls._jvm = gateway.jvm
+ cls.longMessage = True
+ random.seed(42)
+
+ def _test_chunked_stream(self, data, py_buf_size):
+ # write data using the chunked protocol from python.
+ chunked_file = tempfile.NamedTemporaryFile(delete=False)
+ dechunked_file = tempfile.NamedTemporaryFile(delete=False)
+ dechunked_file.close()
+ try:
+ out = ChunkedStream(chunked_file, py_buf_size)
+ out.write(data)
+ out.close()
+ # now try to read it in java
+ jin = self._jvm.java.io.FileInputStream(chunked_file.name)
+ jout = self._jvm.java.io.FileOutputStream(dechunked_file.name)
+ self._jvm.DechunkedInputStream.dechunkAndCopyToOutput(jin, jout)
+ # java should have decoded it back to the original data
+ self.assertEqual(len(data), os.stat(dechunked_file.name).st_size)
+ with open(dechunked_file.name, "rb") as f:
+ byte = f.read(1)
+ idx = 0
+ while byte:
+ self.assertEqual(data[idx], bytearray(byte)[0], msg="idx = " + str(idx))
+ byte = f.read(1)
+ idx += 1
+ finally:
+ os.unlink(chunked_file.name)
+ os.unlink(dechunked_file.name)
+
+ def test_chunked_stream(self):
+ def random_bytes(n):
+ return bytearray(random.getrandbits(8) for _ in range(n))
+ for data_length in [1, 10, 100, 10000]:
+ for buffer_length in [1, 2, 5, 8192]:
+ self._test_chunked_stream(random_bytes(data_length), buffer_length)
+
+if __name__ == '__main__':
+ from pyspark.test_broadcast import *
+ if xmlrunner:
+ unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+ else:
+ unittest.main(verbosity=2)
http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/python/pyspark/test_serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/test_serializers.py b/python/pyspark/test_serializers.py
new file mode 100644
index 0000000..5b43729
--- /dev/null
+++ b/python/pyspark/test_serializers.py
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import io
+import math
+import struct
+import sys
+import unittest
+
+try:
+ import xmlrunner
+except ImportError:
+ xmlrunner = None
+
+from pyspark import serializers
+
+
+def read_int(b):
+ return struct.unpack("!i", b)[0]
+
+
+def write_int(i):
+ return struct.pack("!i", i)
+
+
+class SerializersTest(unittest.TestCase):
+
+ def test_chunked_stream(self):
+ original_bytes = bytearray(range(100))
+ for data_length in [1, 10, 100]:
+ for buffer_length in [1, 2, 3, 5, 20, 99, 100, 101, 500]:
+ dest = ByteArrayOutput()
+ stream_out = serializers.ChunkedStream(dest, buffer_length)
+ stream_out.write(original_bytes[:data_length])
+ stream_out.close()
+ num_chunks = int(math.ceil(float(data_length) / buffer_length))
+ # length for each chunk, and a final -1 at the very end
+ exp_size = (num_chunks + 1) * 4 + data_length
+ self.assertEqual(len(dest.buffer), exp_size)
+ dest_pos = 0
+ data_pos = 0
+ for chunk_idx in range(num_chunks):
+ chunk_length = read_int(dest.buffer[dest_pos:(dest_pos + 4)])
+ if chunk_idx == num_chunks - 1:
+ exp_length = data_length % buffer_length
+ if exp_length == 0:
+ exp_length = buffer_length
+ else:
+ exp_length = buffer_length
+ self.assertEqual(chunk_length, exp_length)
+ dest_pos += 4
+ dest_chunk = dest.buffer[dest_pos:dest_pos + chunk_length]
+ orig_chunk = original_bytes[data_pos:data_pos + chunk_length]
+ self.assertEqual(dest_chunk, orig_chunk)
+ dest_pos += chunk_length
+ data_pos += chunk_length
+ # ends with a -1
+ self.assertEqual(dest.buffer[-4:], write_int(-1))
+
+
+class ByteArrayOutput(object):
+ def __init__(self):
+ self.buffer = bytearray()
+
+ def write(self, b):
+ self.buffer += b
+
+ def close(self):
+ pass
+
+if __name__ == '__main__':
+ from pyspark.test_serializers import *
+ if xmlrunner:
+ unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+ else:
+ unittest.main(verbosity=2)
http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 8ac1df5..050c2dd 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -373,8 +373,15 @@ class PySparkTestCase(unittest.TestCase):
class ReusedPySparkTestCase(unittest.TestCase):
@classmethod
+ def conf(cls):
+ """
+ Override this in subclasses to supply a more specific conf
+ """
+ return SparkConf()
+
+ @classmethod
def setUpClass(cls):
- cls.sc = SparkContext('local[4]', cls.__name__)
+ cls.sc = SparkContext('local[4]', cls.__name__, conf=cls.conf())
@classmethod
def tearDownClass(cls):
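The new conf() hook lets a subclass tweak the SparkConf used for the shared context; for example, a hypothetical test case that turns on IO encryption (mirroring EncryptionArrowTests above) would look like:

from pyspark.tests import ReusedPySparkTestCase

class EncryptedPySparkTestCase(ReusedPySparkTestCase):
    @classmethod
    def conf(cls):
        # start the shared SparkContext with IO encryption enabled
        return super(EncryptedPySparkTestCase, cls).conf().set("spark.io.encryption.enabled", "true")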
http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index e934da4..974344f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -324,16 +324,34 @@ def main(infile, outfile):
importlib.invalidate_caches()
# fetch names and values of broadcast variables
+ needs_broadcast_decryption_server = read_bool(infile)
num_broadcast_variables = read_int(infile)
+ if needs_broadcast_decryption_server:
+ # read the decrypted data from a server in the jvm
+ port = read_int(infile)
+ auth_secret = utf8_deserializer.loads(infile)
+ (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)
+
for _ in range(num_broadcast_variables):
bid = read_long(infile)
if bid >= 0:
- path = utf8_deserializer.loads(infile)
- _broadcastRegistry[bid] = Broadcast(path=path)
+ if needs_broadcast_decryption_server:
+ read_bid = read_long(broadcast_sock_file)
+ assert(read_bid == bid)
+ _broadcastRegistry[bid] = \
+ Broadcast(sock_file=broadcast_sock_file)
+ else:
+ path = utf8_deserializer.loads(infile)
+ _broadcastRegistry[bid] = Broadcast(path=path)
+
else:
bid = - bid - 1
_broadcastRegistry.pop(bid)
+ if needs_broadcast_decryption_server:
+ broadcast_sock_file.write(b'1')
+ broadcast_sock_file.close()
+
_accumulatorRegistry.clear()
eval_type = read_int(infile)
if eval_type == PythonEvalType.NON_UDF:
http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index c0830e7..482e2bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -17,6 +17,12 @@
package org.apache.spark.sql.api.python
+import java.io.InputStream
+import java.nio.channels.Channels
+
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.python.PythonRDDServer
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
@@ -33,19 +39,36 @@ private[sql] object PythonSQLUtils {
}
/**
- * Python callable function to read a file in Arrow stream format and create a [[DataFrame]]
+ * Python callable function to read a file in Arrow stream format and create an [[RDD]]
* using each serialized ArrowRecordBatch as a partition.
- *
- * @param sqlContext The active [[SQLContext]].
- * @param filename File to read the Arrow stream from.
- * @param schemaString JSON Formatted Spark schema for Arrow batches.
- * @return A new [[DataFrame]].
*/
- def arrowReadStreamFromFile(
- sqlContext: SQLContext,
- filename: String,
- schemaString: String): DataFrame = {
- val jrdd = ArrowConverters.readArrowStreamFromFile(sqlContext, filename)
- ArrowConverters.toDataFrame(jrdd, schemaString, sqlContext)
+ def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): JavaRDD[Array[Byte]] = {
+ ArrowConverters.readArrowStreamFromFile(sqlContext, filename)
+ }
+
+ /**
+ * Python callable function to read a file in Arrow stream format and create a [[DataFrame]]
+ * from an RDD.
+ */
+ def toDataFrame(
+ arrowBatchRDD: JavaRDD[Array[Byte]],
+ schemaString: String,
+ sqlContext: SQLContext): DataFrame = {
+ ArrowConverters.toDataFrame(arrowBatchRDD, schemaString, sqlContext)
}
}
+
+/**
+ * Helper for making a DataFrame from Arrow data sent from python over a socket. This is
+ * used when encryption is enabled, and we don't want to write data to a file.
+ */
+private[sql] class ArrowRDDServer(sqlContext: SQLContext) extends PythonRDDServer {
+
+ override protected def streamToRDD(input: InputStream): RDD[Array[Byte]] = {
+ // Create array to consume iterator so that we can safely close the inputStream
+ val batches = ArrowConverters.getBatchesFromStream(Channels.newChannel(input)).toArray
+ // Parallelize the record batches to create an RDD
+ JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length))
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/58419b92/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 1a48bc8..2bf6a58 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.arrow
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, OutputStream}
-import java.nio.channels.{Channels, SeekableByteChannel}
+import java.nio.channels.{Channels, ReadableByteChannel}
import scala.collection.JavaConverters._
@@ -31,6 +31,7 @@ import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer}
import org.apache.spark.TaskContext
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
@@ -189,7 +190,7 @@ private[sql] object ArrowConverters {
}
/**
- * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches.
+ * Create a DataFrame from an RDD of serialized ArrowRecordBatches.
*/
private[sql] def toDataFrame(
arrowBatchRDD: JavaRDD[Array[Byte]],
@@ -221,7 +222,7 @@ private[sql] object ArrowConverters {
/**
* Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches.
*/
- private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = {
+ private[sql] def getBatchesFromStream(in: ReadableByteChannel): Iterator[Array[Byte]] = {
// Iterate over the serialized Arrow RecordBatch messages from a stream
new Iterator[Array[Byte]] {
@@ -271,7 +272,7 @@ private[sql] object ArrowConverters {
} else {
if (bodyLength > 0) {
// Skip message body if not a RecordBatch
- in.position(in.position() + bodyLength)
+ Channels.newInputStream(in).skip(bodyLength)
}
// Proceed to next message