Posted to commits@spark.apache.org by gu...@apache.org on 2017/09/27 14:21:50 UTC

spark git commit: [SPARK-22125][PYSPARK][SQL] Enable Arrow Stream format for vectorized UDF.

Repository: spark
Updated Branches:
  refs/heads/master 12e740bba -> 09cbf3df2


[SPARK-22125][PYSPARK][SQL] Enable Arrow Stream format for vectorized UDF.

## What changes were proposed in this pull request?

Currently we use the Arrow File format to communicate with the Python worker when invoking a vectorized UDF, but we can use the Arrow Stream format instead.

This PR replaces the Arrow File format with the Arrow Stream format.
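
For context, the user-facing API whose data path this changes is the vectorized (pandas) UDF. A minimal usage sketch, assuming a Spark build that already includes vectorized UDF support and an existing SparkSession named `spark`:

    from pyspark.sql.functions import pandas_udf
    from pyspark.sql.types import LongType

    # A vectorized UDF receives a pandas.Series per batch instead of one row at a
    # time; with this change the batches travel to the Python worker as an Arrow stream.
    plus_one = pandas_udf(lambda s: s + 1, LongType())

    df = spark.range(0, 1000)
    df.select(plus_one(df["id"]).alias("id_plus_one")).show()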

## How was this patch tested?

Existing tests.

Author: Takuya UESHIN <ue...@databricks.com>

Closes #19349 from ueshin/issues/SPARK-22125.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/09cbf3df
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/09cbf3df
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/09cbf3df

Branch: refs/heads/master
Commit: 09cbf3df20efea09c0941499249b7a3b2bf7e9fd
Parents: 12e740b
Author: Takuya UESHIN <ue...@databricks.com>
Authored: Wed Sep 27 23:21:44 2017 +0900
Committer: hyukjinkwon <gu...@gmail.com>
Committed: Wed Sep 27 23:21:44 2017 +0900

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala | 325 +-------------
 .../apache/spark/api/python/PythonRunner.scala  | 441 +++++++++++++++++++
 python/pyspark/serializers.py                   |  70 +--
 python/pyspark/worker.py                        |   4 +-
 .../sql/execution/vectorized/ColumnarBatch.java |   5 +
 .../execution/python/ArrowEvalPythonExec.scala  |  54 ++-
 .../execution/python/ArrowPythonRunner.scala    | 181 ++++++++
 .../execution/python/BatchEvalPythonExec.scala  |   4 +-
 .../sql/execution/python/PythonUDFRunner.scala  | 113 +++++
 9 files changed, 825 insertions(+), 372 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/09cbf3df/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 86d0405..f6293c0 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -48,7 +48,7 @@ private[spark] class PythonRDD(
   extends RDD[Array[Byte]](parent) {
 
   val bufferSize = conf.getInt("spark.buffer.size", 65536)
-  val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)
+  val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true)
 
   override def getPartitions: Array[Partition] = firstParent.partitions
 
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
   val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 
   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
-    val runner = PythonRunner(func, bufferSize, reuse_worker)
+    val runner = PythonRunner(func, bufferSize, reuseWorker)
     runner.compute(firstParent.iterator(split, context), split.index, context)
   }
 }
@@ -83,318 +83,9 @@ private[spark] case class PythonFunction(
  */
 private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
 
-/**
- * Enumerate the type of command that will be sent to the Python worker
- */
-private[spark] object PythonEvalType {
-  val NON_UDF = 0
-  val SQL_BATCHED_UDF = 1
-  val SQL_PANDAS_UDF = 2
-}
-
-private[spark] object PythonRunner {
-  def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
-    new PythonRunner(
-      Seq(ChainedPythonFunctions(Seq(func))),
-      bufferSize,
-      reuse_worker,
-      PythonEvalType.NON_UDF,
-      Array(Array(0)))
-  }
-}
-
-/**
- * A helper class to run Python mapPartition/UDFs in Spark.
- *
- * funcs is a list of independent Python functions, each one of them is a list of chained Python
- * functions (from bottom to top).
- */
-private[spark] class PythonRunner(
-    funcs: Seq[ChainedPythonFunctions],
-    bufferSize: Int,
-    reuse_worker: Boolean,
-    evalType: Int,
-    argOffsets: Array[Array[Int]])
-  extends Logging {
-
-  require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
-
-  // All the Python functions should have the same exec, version and envvars.
-  private val envVars = funcs.head.funcs.head.envVars
-  private val pythonExec = funcs.head.funcs.head.pythonExec
-  private val pythonVer = funcs.head.funcs.head.pythonVer
-
-  // TODO: support accumulator in multiple UDF
-  private val accumulator = funcs.head.funcs.head.accumulator
-
-  def compute(
-      inputIterator: Iterator[_],
-      partitionIndex: Int,
-      context: TaskContext): Iterator[Array[Byte]] = {
-    val startTime = System.currentTimeMillis
-    val env = SparkEnv.get
-    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
-    envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread
-    if (reuse_worker) {
-      envVars.put("SPARK_REUSE_WORKER", "1")
-    }
-    val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
-    // Whether is the worker released into idle pool
-    @volatile var released = false
-
-    // Start a thread to feed the process input from our parent's iterator
-    val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context)
-
-    context.addTaskCompletionListener { context =>
-      writerThread.shutdownOnTaskCompletion()
-      if (!reuse_worker || !released) {
-        try {
-          worker.close()
-        } catch {
-          case e: Exception =>
-            logWarning("Failed to close worker socket", e)
-        }
-      }
-    }
-
-    writerThread.start()
-    new MonitorThread(env, worker, context).start()
-
-    // Return an iterator that read lines from the process's stdout
-    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
-    val stdoutIterator = new Iterator[Array[Byte]] {
-      override def next(): Array[Byte] = {
-        val obj = _nextObj
-        if (hasNext) {
-          _nextObj = read()
-        }
-        obj
-      }
-
-      private def read(): Array[Byte] = {
-        if (writerThread.exception.isDefined) {
-          throw writerThread.exception.get
-        }
-        try {
-          stream.readInt() match {
-            case length if length > 0 =>
-              val obj = new Array[Byte](length)
-              stream.readFully(obj)
-              obj
-            case 0 => Array.empty[Byte]
-            case SpecialLengths.TIMING_DATA =>
-              // Timing data from worker
-              val bootTime = stream.readLong()
-              val initTime = stream.readLong()
-              val finishTime = stream.readLong()
-              val boot = bootTime - startTime
-              val init = initTime - bootTime
-              val finish = finishTime - initTime
-              val total = finishTime - startTime
-              logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
-                init, finish))
-              val memoryBytesSpilled = stream.readLong()
-              val diskBytesSpilled = stream.readLong()
-              context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
-              context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
-              read()
-            case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
-              // Signals that an exception has been thrown in python
-              val exLength = stream.readInt()
-              val obj = new Array[Byte](exLength)
-              stream.readFully(obj)
-              throw new PythonException(new String(obj, StandardCharsets.UTF_8),
-                writerThread.exception.getOrElse(null))
-            case SpecialLengths.END_OF_DATA_SECTION =>
-              // We've finished the data section of the output, but we can still
-              // read some accumulator updates:
-              val numAccumulatorUpdates = stream.readInt()
-              (1 to numAccumulatorUpdates).foreach { _ =>
-                val updateLen = stream.readInt()
-                val update = new Array[Byte](updateLen)
-                stream.readFully(update)
-                accumulator.add(update)
-              }
-              // Check whether the worker is ready to be re-used.
-              if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
-                if (reuse_worker) {
-                  env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
-                  released = true
-                }
-              }
-              null
-          }
-        } catch {
-
-          case e: Exception if context.isInterrupted =>
-            logDebug("Exception thrown after task interruption", e)
-            throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason"))
-
-          case e: Exception if env.isStopped =>
-            logDebug("Exception thrown after context is stopped", e)
-            null  // exit silently
-
-          case e: Exception if writerThread.exception.isDefined =>
-            logError("Python worker exited unexpectedly (crashed)", e)
-            logError("This may have been caused by a prior exception:", writerThread.exception.get)
-            throw writerThread.exception.get
-
-          case eof: EOFException =>
-            throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
-        }
-      }
-
-      var _nextObj = read()
-
-      override def hasNext: Boolean = _nextObj != null
-    }
-    new InterruptibleIterator(context, stdoutIterator)
-  }
-
-  /**
-   * The thread responsible for writing the data from the PythonRDD's parent iterator to the
-   * Python process.
-   */
-  class WriterThread(
-      env: SparkEnv,
-      worker: Socket,
-      inputIterator: Iterator[_],
-      partitionIndex: Int,
-      context: TaskContext)
-    extends Thread(s"stdout writer for $pythonExec") {
-
-    @volatile private var _exception: Exception = null
-
-    private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
-    private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
-
-    setDaemon(true)
-
-    /** Contains the exception thrown while writing the parent iterator to the Python process. */
-    def exception: Option[Exception] = Option(_exception)
-
-    /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
-    def shutdownOnTaskCompletion() {
-      assert(context.isCompleted)
-      this.interrupt()
-    }
-
-    override def run(): Unit = Utils.logUncaughtExceptions {
-      try {
-        TaskContext.setTaskContext(context)
-        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
-        val dataOut = new DataOutputStream(stream)
-        // Partition index
-        dataOut.writeInt(partitionIndex)
-        // Python version of driver
-        PythonRDD.writeUTF(pythonVer, dataOut)
-        // Write out the TaskContextInfo
-        dataOut.writeInt(context.stageId())
-        dataOut.writeInt(context.partitionId())
-        dataOut.writeInt(context.attemptNumber())
-        dataOut.writeLong(context.taskAttemptId())
-        // sparkFilesDir
-        PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
-        // Python includes (*.zip and *.egg files)
-        dataOut.writeInt(pythonIncludes.size)
-        for (include <- pythonIncludes) {
-          PythonRDD.writeUTF(include, dataOut)
-        }
-        // Broadcast variables
-        val oldBids = PythonRDD.getWorkerBroadcasts(worker)
-        val newBids = broadcastVars.map(_.id).toSet
-        // number of different broadcasts
-        val toRemove = oldBids.diff(newBids)
-        val cnt = toRemove.size + newBids.diff(oldBids).size
-        dataOut.writeInt(cnt)
-        for (bid <- toRemove) {
-          // remove the broadcast from worker
-          dataOut.writeLong(- bid - 1)  // bid >= 0
-          oldBids.remove(bid)
-        }
-        for (broadcast <- broadcastVars) {
-          if (!oldBids.contains(broadcast.id)) {
-            // send new broadcast
-            dataOut.writeLong(broadcast.id)
-            PythonRDD.writeUTF(broadcast.value.path, dataOut)
-            oldBids.add(broadcast.id)
-          }
-        }
-        dataOut.flush()
-        // Serialized command:
-        dataOut.writeInt(evalType)
-        if (evalType != PythonEvalType.NON_UDF) {
-          dataOut.writeInt(funcs.length)
-          funcs.zip(argOffsets).foreach { case (chained, offsets) =>
-            dataOut.writeInt(offsets.length)
-            offsets.foreach { offset =>
-              dataOut.writeInt(offset)
-            }
-            dataOut.writeInt(chained.funcs.length)
-            chained.funcs.foreach { f =>
-              dataOut.writeInt(f.command.length)
-              dataOut.write(f.command)
-            }
-          }
-        } else {
-          val command = funcs.head.funcs.head.command
-          dataOut.writeInt(command.length)
-          dataOut.write(command)
-        }
-        // Data values
-        PythonRDD.writeIteratorToStream(inputIterator, dataOut)
-        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
-        dataOut.writeInt(SpecialLengths.END_OF_STREAM)
-        dataOut.flush()
-      } catch {
-        case e: Exception if context.isCompleted || context.isInterrupted =>
-          logDebug("Exception thrown after task completion (likely due to cleanup)", e)
-          if (!worker.isClosed) {
-            Utils.tryLog(worker.shutdownOutput())
-          }
-
-        case e: Exception =>
-          // We must avoid throwing exceptions here, because the thread uncaught exception handler
-          // will kill the whole executor (see org.apache.spark.executor.Executor).
-          _exception = e
-          if (!worker.isClosed) {
-            Utils.tryLog(worker.shutdownOutput())
-          }
-      }
-    }
-  }
-
-  /**
-   * It is necessary to have a monitor thread for python workers if the user cancels with
-   * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the
-   * threads can block indefinitely.
-   */
-  class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext)
-    extends Thread(s"Worker Monitor for $pythonExec") {
-
-    setDaemon(true)
-
-    override def run() {
-      // Kill the worker if it is interrupted, checking until task completion.
-      // TODO: This has a race condition if interruption occurs, as completed may still become true.
-      while (!context.isInterrupted && !context.isCompleted) {
-        Thread.sleep(2000)
-      }
-      if (!context.isCompleted) {
-        try {
-          logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
-          env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker)
-        } catch {
-          case e: Exception =>
-            logError("Exception when trying to kill worker", e)
-        }
-      }
-    }
-  }
-}
-
 /** Thrown for exceptions in user Python code. */
-private class PythonException(msg: String, cause: Exception) extends RuntimeException(msg, cause)
+private[spark] class PythonException(msg: String, cause: Exception)
+  extends RuntimeException(msg, cause)
 
 /**
  * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
@@ -411,14 +102,6 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte]
   val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this)
 }
 
-private object SpecialLengths {
-  val END_OF_DATA_SECTION = -1
-  val PYTHON_EXCEPTION_THROWN = -2
-  val TIMING_DATA = -3
-  val END_OF_STREAM = -4
-  val NULL = -5
-}
-
 private[spark] object PythonRDD extends Logging {
 
   // remember the broadcasts sent to each worker

http://git-wip-us.apache.org/repos/asf/spark/blob/09cbf3df/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
new file mode 100644
index 0000000..3688a14
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -0,0 +1,441 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.python
+
+import java.io._
+import java.net._
+import java.nio.charset.StandardCharsets
+import java.util.concurrent.atomic.AtomicBoolean
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark._
+import org.apache.spark.internal.Logging
+import org.apache.spark.util._
+
+
+/**
+ * Enumerate the type of command that will be sent to the Python worker
+ */
+private[spark] object PythonEvalType {
+  val NON_UDF = 0
+  val SQL_BATCHED_UDF = 1
+  val SQL_PANDAS_UDF = 2
+}
+
+/**
+ * A helper class to run Python mapPartition/UDFs in Spark.
+ *
+ * funcs is a list of independent Python functions, each one of them is a list of chained Python
+ * functions (from bottom to top).
+ */
+private[spark] abstract class BasePythonRunner[IN, OUT](
+    funcs: Seq[ChainedPythonFunctions],
+    bufferSize: Int,
+    reuseWorker: Boolean,
+    evalType: Int,
+    argOffsets: Array[Array[Int]])
+  extends Logging {
+
+  require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
+
+  // All the Python functions should have the same exec, version and envvars.
+  protected val envVars = funcs.head.funcs.head.envVars
+  protected val pythonExec = funcs.head.funcs.head.pythonExec
+  protected val pythonVer = funcs.head.funcs.head.pythonVer
+
+  // TODO: support accumulator in multiple UDF
+  protected val accumulator = funcs.head.funcs.head.accumulator
+
+  def compute(
+      inputIterator: Iterator[IN],
+      partitionIndex: Int,
+      context: TaskContext): Iterator[OUT] = {
+    val startTime = System.currentTimeMillis
+    val env = SparkEnv.get
+    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
+    envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread
+    if (reuseWorker) {
+      envVars.put("SPARK_REUSE_WORKER", "1")
+    }
+    val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
+    // Whether is the worker released into idle pool
+    val released = new AtomicBoolean(false)
+
+    // Start a thread to feed the process input from our parent's iterator
+    val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context)
+
+    context.addTaskCompletionListener { _ =>
+      writerThread.shutdownOnTaskCompletion()
+      if (!reuseWorker || !released.get) {
+        try {
+          worker.close()
+        } catch {
+          case e: Exception =>
+            logWarning("Failed to close worker socket", e)
+        }
+      }
+    }
+
+    writerThread.start()
+    new MonitorThread(env, worker, context).start()
+
+    // Return an iterator that read lines from the process's stdout
+    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
+
+    val stdoutIterator = newReaderIterator(
+      stream, writerThread, startTime, env, worker, released, context)
+    new InterruptibleIterator(context, stdoutIterator)
+  }
+
+  protected def newWriterThread(
+      env: SparkEnv,
+      worker: Socket,
+      inputIterator: Iterator[IN],
+      partitionIndex: Int,
+      context: TaskContext): WriterThread
+
+  protected def newReaderIterator(
+      stream: DataInputStream,
+      writerThread: WriterThread,
+      startTime: Long,
+      env: SparkEnv,
+      worker: Socket,
+      released: AtomicBoolean,
+      context: TaskContext): Iterator[OUT]
+
+  /**
+   * The thread responsible for writing the data from the PythonRDD's parent iterator to the
+   * Python process.
+   */
+  abstract class WriterThread(
+      env: SparkEnv,
+      worker: Socket,
+      inputIterator: Iterator[IN],
+      partitionIndex: Int,
+      context: TaskContext)
+    extends Thread(s"stdout writer for $pythonExec") {
+
+    @volatile private var _exception: Exception = null
+
+    private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
+    private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
+
+    setDaemon(true)
+
+    /** Contains the exception thrown while writing the parent iterator to the Python process. */
+    def exception: Option[Exception] = Option(_exception)
+
+    /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
+    def shutdownOnTaskCompletion() {
+      assert(context.isCompleted)
+      this.interrupt()
+    }
+
+    /**
+     * Writes a command section to the stream connected to the Python worker.
+     */
+    protected def writeCommand(dataOut: DataOutputStream): Unit
+
+    /**
+     * Writes input data to the stream connected to the Python worker.
+     */
+    protected def writeIteratorToStream(dataOut: DataOutputStream): Unit
+
+    override def run(): Unit = Utils.logUncaughtExceptions {
+      try {
+        TaskContext.setTaskContext(context)
+        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
+        val dataOut = new DataOutputStream(stream)
+        // Partition index
+        dataOut.writeInt(partitionIndex)
+        // Python version of driver
+        PythonRDD.writeUTF(pythonVer, dataOut)
+        // Write out the TaskContextInfo
+        dataOut.writeInt(context.stageId())
+        dataOut.writeInt(context.partitionId())
+        dataOut.writeInt(context.attemptNumber())
+        dataOut.writeLong(context.taskAttemptId())
+        // sparkFilesDir
+        PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
+        // Python includes (*.zip and *.egg files)
+        dataOut.writeInt(pythonIncludes.size)
+        for (include <- pythonIncludes) {
+          PythonRDD.writeUTF(include, dataOut)
+        }
+        // Broadcast variables
+        val oldBids = PythonRDD.getWorkerBroadcasts(worker)
+        val newBids = broadcastVars.map(_.id).toSet
+        // number of different broadcasts
+        val toRemove = oldBids.diff(newBids)
+        val cnt = toRemove.size + newBids.diff(oldBids).size
+        dataOut.writeInt(cnt)
+        for (bid <- toRemove) {
+          // remove the broadcast from worker
+          dataOut.writeLong(- bid - 1)  // bid >= 0
+          oldBids.remove(bid)
+        }
+        for (broadcast <- broadcastVars) {
+          if (!oldBids.contains(broadcast.id)) {
+            // send new broadcast
+            dataOut.writeLong(broadcast.id)
+            PythonRDD.writeUTF(broadcast.value.path, dataOut)
+            oldBids.add(broadcast.id)
+          }
+        }
+        dataOut.flush()
+
+        dataOut.writeInt(evalType)
+        writeCommand(dataOut)
+        writeIteratorToStream(dataOut)
+
+        dataOut.writeInt(SpecialLengths.END_OF_STREAM)
+        dataOut.flush()
+      } catch {
+        case e: Exception if context.isCompleted || context.isInterrupted =>
+          logDebug("Exception thrown after task completion (likely due to cleanup)", e)
+          if (!worker.isClosed) {
+            Utils.tryLog(worker.shutdownOutput())
+          }
+
+        case e: Exception =>
+          // We must avoid throwing exceptions here, because the thread uncaught exception handler
+          // will kill the whole executor (see org.apache.spark.executor.Executor).
+          _exception = e
+          if (!worker.isClosed) {
+            Utils.tryLog(worker.shutdownOutput())
+          }
+      }
+    }
+  }
+
+  abstract class ReaderIterator(
+      stream: DataInputStream,
+      writerThread: WriterThread,
+      startTime: Long,
+      env: SparkEnv,
+      worker: Socket,
+      released: AtomicBoolean,
+      context: TaskContext)
+    extends Iterator[OUT] {
+
+    private var nextObj: OUT = _
+    private var eos = false
+
+    override def hasNext: Boolean = nextObj != null || {
+      if (!eos) {
+        nextObj = read()
+        hasNext
+      } else {
+        false
+      }
+    }
+
+    override def next(): OUT = {
+      if (hasNext) {
+        val obj = nextObj
+        nextObj = null.asInstanceOf[OUT]
+        obj
+      } else {
+        Iterator.empty.next()
+      }
+    }
+
+    /**
+     * Reads next object from the stream.
+     * When the stream reaches end of data, needs to process the following sections,
+     * and then returns null.
+     */
+    protected def read(): OUT
+
+    protected def handleTimingData(): Unit = {
+      // Timing data from worker
+      val bootTime = stream.readLong()
+      val initTime = stream.readLong()
+      val finishTime = stream.readLong()
+      val boot = bootTime - startTime
+      val init = initTime - bootTime
+      val finish = finishTime - initTime
+      val total = finishTime - startTime
+      logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
+        init, finish))
+      val memoryBytesSpilled = stream.readLong()
+      val diskBytesSpilled = stream.readLong()
+      context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
+      context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
+    }
+
+    protected def handlePythonException(): PythonException = {
+      // Signals that an exception has been thrown in python
+      val exLength = stream.readInt()
+      val obj = new Array[Byte](exLength)
+      stream.readFully(obj)
+      new PythonException(new String(obj, StandardCharsets.UTF_8),
+        writerThread.exception.getOrElse(null))
+    }
+
+    protected def handleEndOfDataSection(): Unit = {
+      // We've finished the data section of the output, but we can still
+      // read some accumulator updates:
+      val numAccumulatorUpdates = stream.readInt()
+      (1 to numAccumulatorUpdates).foreach { _ =>
+        val updateLen = stream.readInt()
+        val update = new Array[Byte](updateLen)
+        stream.readFully(update)
+        accumulator.add(update)
+      }
+      // Check whether the worker is ready to be re-used.
+      if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
+        if (reuseWorker) {
+          env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
+          released.set(true)
+        }
+      }
+      eos = true
+    }
+
+    protected val handleException: PartialFunction[Throwable, OUT] = {
+      case e: Exception if context.isInterrupted =>
+        logDebug("Exception thrown after task interruption", e)
+        throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason"))
+
+      case e: Exception if env.isStopped =>
+        logDebug("Exception thrown after context is stopped", e)
+        null.asInstanceOf[OUT]  // exit silently
+
+      case e: Exception if writerThread.exception.isDefined =>
+        logError("Python worker exited unexpectedly (crashed)", e)
+        logError("This may have been caused by a prior exception:", writerThread.exception.get)
+        throw writerThread.exception.get
+
+      case eof: EOFException =>
+        throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
+    }
+  }
+
+  /**
+   * It is necessary to have a monitor thread for python workers if the user cancels with
+   * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the
+   * threads can block indefinitely.
+   */
+  class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext)
+    extends Thread(s"Worker Monitor for $pythonExec") {
+
+    setDaemon(true)
+
+    override def run() {
+      // Kill the worker if it is interrupted, checking until task completion.
+      // TODO: This has a race condition if interruption occurs, as completed may still become true.
+      while (!context.isInterrupted && !context.isCompleted) {
+        Thread.sleep(2000)
+      }
+      if (!context.isCompleted) {
+        try {
+          logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
+          env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker)
+        } catch {
+          case e: Exception =>
+            logError("Exception when trying to kill worker", e)
+        }
+      }
+    }
+  }
+}
+
+private[spark] object PythonRunner {
+
+  def apply(func: PythonFunction, bufferSize: Int, reuseWorker: Boolean): PythonRunner = {
+    new PythonRunner(Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuseWorker)
+  }
+}
+
+/**
+ * A helper class to run Python mapPartition in Spark.
+ */
+private[spark] class PythonRunner(
+    funcs: Seq[ChainedPythonFunctions],
+    bufferSize: Int,
+    reuseWorker: Boolean)
+  extends BasePythonRunner[Array[Byte], Array[Byte]](
+    funcs, bufferSize, reuseWorker, PythonEvalType.NON_UDF, Array(Array(0))) {
+
+  protected override def newWriterThread(
+      env: SparkEnv,
+      worker: Socket,
+      inputIterator: Iterator[Array[Byte]],
+      partitionIndex: Int,
+      context: TaskContext): WriterThread = {
+    new WriterThread(env, worker, inputIterator, partitionIndex, context) {
+
+      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
+        val command = funcs.head.funcs.head.command
+        dataOut.writeInt(command.length)
+        dataOut.write(command)
+      }
+
+      protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
+        PythonRDD.writeIteratorToStream(inputIterator, dataOut)
+        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+      }
+    }
+  }
+
+  protected override def newReaderIterator(
+      stream: DataInputStream,
+      writerThread: WriterThread,
+      startTime: Long,
+      env: SparkEnv,
+      worker: Socket,
+      released: AtomicBoolean,
+      context: TaskContext): Iterator[Array[Byte]] = {
+    new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) {
+
+      protected override def read(): Array[Byte] = {
+        if (writerThread.exception.isDefined) {
+          throw writerThread.exception.get
+        }
+        try {
+          stream.readInt() match {
+            case length if length > 0 =>
+              val obj = new Array[Byte](length)
+              stream.readFully(obj)
+              obj
+            case 0 => Array.empty[Byte]
+            case SpecialLengths.TIMING_DATA =>
+              handleTimingData()
+              read()
+            case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
+              throw handlePythonException()
+            case SpecialLengths.END_OF_DATA_SECTION =>
+              handleEndOfDataSection()
+              null
+          }
+        } catch handleException
+      }
+    }
+  }
+}
+
+private[spark] object SpecialLengths {
+  val END_OF_DATA_SECTION = -1
+  val PYTHON_EXCEPTION_THROWN = -2
+  val TIMING_DATA = -3
+  val END_OF_STREAM = -4
+  val NULL = -5
+  val START_ARROW_STREAM = -6
+}
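
For readers following the framing above: every section exchanged between the JVM and the Python worker is prefixed with a signed 4-byte big-endian int, and the negative values in `SpecialLengths` act as control codes. The new `START_ARROW_STREAM = -6` (written by the Python worker and read by `ArrowPythonRunner` below) announces that an Arrow IPC stream follows rather than another length-prefixed blob. A rough, hypothetical sketch of such a reader's dispatch, not actual Spark code:

    import struct

    END_OF_DATA_SECTION, PYTHON_EXCEPTION_THROWN, TIMING_DATA = -1, -2, -3
    END_OF_STREAM, NULL, START_ARROW_STREAM = -4, -5, -6

    def read_int(stream):
        # 4-byte big-endian signed int, the counterpart of DataOutputStream.writeInt
        return struct.unpack("!i", stream.read(4))[0]

    def read_message(stream):
        """Read one framed message; negative control codes mirror SpecialLengths."""
        length = read_int(stream)
        if length >= 0:
            return ("payload", stream.read(length))      # ordinary length-prefixed blob
        if length == NULL:
            return ("payload", None)                     # a null object in the data stream
        if length == START_ARROW_STREAM:
            return ("arrow_stream", None)                # an Arrow IPC stream follows on `stream`
        if length == TIMING_DATA:
            # boot, init, finish timestamps plus memory/disk bytes spilled (five 8-byte longs)
            return ("timing", struct.unpack("!5q", stream.read(40)))
        if length == PYTHON_EXCEPTION_THROWN:
            return ("error", stream.read(read_int(stream)))  # UTF-8 traceback bytes
        if length == END_OF_DATA_SECTION:
            return ("end_of_data", None)                 # accumulator updates follow
        if length == END_OF_STREAM:
            return ("end_of_stream", None)
        raise ValueError("unexpected control code: %d" % length)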

http://git-wip-us.apache.org/repos/asf/spark/blob/09cbf3df/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 7c1fbad..db77b7e 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -79,6 +79,7 @@ class SpecialLengths(object):
     TIMING_DATA = -3
     END_OF_STREAM = -4
     NULL = -5
+    START_ARROW_STREAM = -6
 
 
 class PythonEvalType(object):
@@ -211,44 +212,61 @@ class ArrowSerializer(FramedSerializer):
         return "ArrowSerializer"
 
 
-class ArrowPandasSerializer(ArrowSerializer):
+def _create_batch(series):
+    import pyarrow as pa
+    # Make input conform to [(series1, type1), (series2, type2), ...]
+    if not isinstance(series, (list, tuple)) or \
+            (len(series) == 2 and isinstance(series[1], pa.DataType)):
+        series = [series]
+    series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
+
+    # If a nullable integer series has been promoted to floating point with NaNs, need to cast
+    # NOTE: this is not necessary with Arrow >= 0.7
+    def cast_series(s, t):
+        if t is None or s.dtype == t.to_pandas_dtype():
+            return s
+        else:
+            return s.fillna(0).astype(t.to_pandas_dtype(), copy=False)
+
+    arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series]
+    return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
+
+
+class ArrowStreamPandasSerializer(Serializer):
     """
-    Serializes Pandas.Series as Arrow data.
+    Serializes Pandas.Series as Arrow data with Arrow streaming format.
     """
 
-    def dumps(self, series):
+    def dump_stream(self, iterator, stream):
         """
-        Make an ArrowRecordBatch from a Pandas Series and serialize. Input is a single series or
+        Make ArrowRecordBatches from Pandas Serieses and serialize. Input is a single series or
         a list of series accompanied by an optional pyarrow type to coerce the data to.
         """
         import pyarrow as pa
-        # Make input conform to [(series1, type1), (series2, type2), ...]
-        if not isinstance(series, (list, tuple)) or \
-                (len(series) == 2 and isinstance(series[1], pa.DataType)):
-            series = [series]
-        series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
-
-        # If a nullable integer series has been promoted to floating point with NaNs, need to cast
-        # NOTE: this is not necessary with Arrow >= 0.7
-        def cast_series(s, t):
-            if t is None or s.dtype == t.to_pandas_dtype():
-                return s
-            else:
-                return s.fillna(0).astype(t.to_pandas_dtype(), copy=False)
-
-        arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series]
-        batch = pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
-        return super(ArrowPandasSerializer, self).dumps(batch)
+        writer = None
+        try:
+            for series in iterator:
+                batch = _create_batch(series)
+                if writer is None:
+                    write_int(SpecialLengths.START_ARROW_STREAM, stream)
+                    writer = pa.RecordBatchStreamWriter(stream, batch.schema)
+                writer.write_batch(batch)
+        finally:
+            if writer is not None:
+                writer.close()
 
-    def loads(self, obj):
+    def load_stream(self, stream):
         """
-        Deserialize an ArrowRecordBatch to an Arrow table and return as a list of pandas.Series.
+        Deserialize ArrowRecordBatchs to an Arrow table and return as a list of pandas.Series.
         """
-        table = super(ArrowPandasSerializer, self).loads(obj)
-        return [c.to_pandas() for c in table.itercolumns()]
+        import pyarrow as pa
+        reader = pa.open_stream(stream)
+        for batch in reader:
+            table = pa.Table.from_batches([batch])
+            yield [c.to_pandas() for c in table.itercolumns()]
 
     def __repr__(self):
-        return "ArrowPandasSerializer"
+        return "ArrowStreamPandasSerializer"
 
 
 class BatchedSerializer(Serializer):
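
The serializer change above amounts to switching from Arrow's file format to its streaming IPC format, so record batches can be written and consumed one at a time over the worker socket. A self-contained round trip through an in-memory buffer, roughly what `dump_stream`/`load_stream` do (written against the pyarrow API of this era; newer pyarrow spells the reader `pa.ipc.open_stream`):

    import pandas as pd
    import pyarrow as pa

    series = pd.Series([1, 2, 3])
    batch = pa.RecordBatch.from_arrays([pa.Array.from_pandas(series)], ["_0"])

    # dump_stream side: a stream writer can emit record batches incrementally.
    sink = pa.BufferOutputStream()
    writer = pa.RecordBatchStreamWriter(sink, batch.schema)
    writer.write_batch(batch)
    writer.close()

    # load_stream side: read batches lazily and hand each one back as pandas.Series.
    reader = pa.open_stream(pa.BufferReader(sink.getvalue()))
    for batch in reader:
        table = pa.Table.from_batches([batch])
        print([column.to_pandas() for column in table.itercolumns()])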

http://git-wip-us.apache.org/repos/asf/spark/blob/09cbf3df/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index fd917c4..4e24789 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -31,7 +31,7 @@ from pyspark.taskcontext import TaskContext
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
     write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \
-    BatchedSerializer, ArrowPandasSerializer
+    BatchedSerializer, ArrowStreamPandasSerializer
 from pyspark.sql.types import toArrowType
 from pyspark import shuffle
 
@@ -123,7 +123,7 @@ def read_udfs(pickleSer, infile, eval_type):
     func = lambda _, it: map(mapper, it)
 
     if eval_type == PythonEvalType.SQL_PANDAS_UDF:
-        ser = ArrowPandasSerializer()
+        ser = ArrowStreamPandasSerializer()
     else:
         ser = BatchedSerializer(PickleSerializer(), 100)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/09cbf3df/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index e782756..bc546c7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -463,6 +463,11 @@ public final class ColumnarBatch {
   }
 
   /**
+   * Returns the schema that makes up this batch.
+   */
+  public StructType schema() { return schema; }
+
+  /**
    * Returns the max capacity (in number of rows) for this batch.
    */
   public int capacity() { return capacity; }

http://git-wip-us.apache.org/repos/asf/spark/blob/09cbf3df/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index 5e72cd2..f7e8cbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -17,12 +17,13 @@
 
 package org.apache.spark.sql.execution.python
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.TaskContext
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload}
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -39,25 +40,36 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
       iter: Iterator[InternalRow],
       schema: StructType,
       context: TaskContext): Iterator[InternalRow] = {
-    val inputIterator = ArrowConverters.toPayloadIterator(
-      iter, schema, conf.arrowMaxRecordsPerBatch, context).map(_.asPythonSerializable)
-
-    // Output iterator for results from Python.
-    val outputIterator = new PythonRunner(
-        funcs, bufferSize, reuseWorker, PythonEvalType.SQL_PANDAS_UDF, argOffsets)
-      .compute(inputIterator, context.partitionId(), context)
-
-    val outputRowIterator = ArrowConverters.fromPayloadIterator(
-      outputIterator.map(new ArrowPayload(_)), context)
-
-    // Verify that the output schema is correct
-    if (outputRowIterator.hasNext) {
-      val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex
-        .map { case (attr, i) => attr.withName(s"_$i") })
-      assert(schemaOut.equals(outputRowIterator.schema),
-        s"Invalid schema from pandas_udf: expected $schemaOut, got ${outputRowIterator.schema}")
-    }
 
-    outputRowIterator
+    val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex
+      .map { case (attr, i) => attr.withName(s"_$i") })
+
+    val columnarBatchIter = new ArrowPythonRunner(
+        funcs, conf.arrowMaxRecordsPerBatch, bufferSize, reuseWorker,
+        PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema)
+      .compute(iter, context.partitionId(), context)
+
+    new Iterator[InternalRow] {
+
+      var currentIter = if (columnarBatchIter.hasNext) {
+        val batch = columnarBatchIter.next()
+        assert(schemaOut.equals(batch.schema),
+          s"Invalid schema from pandas_udf: expected $schemaOut, got ${batch.schema}")
+        batch.rowIterator.asScala
+      } else {
+        Iterator.empty
+      }
+
+      override def hasNext: Boolean = currentIter.hasNext || {
+        if (columnarBatchIter.hasNext) {
+          currentIter = columnarBatchIter.next().rowIterator.asScala
+          hasNext
+        } else {
+          false
+        }
+      }
+
+      override def next(): InternalRow = currentIter.next()
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/09cbf3df/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
new file mode 100644
index 0000000..bbad9d6
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io._
+import java.net._
+import java.util.concurrent.atomic.AtomicBoolean
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.stream.{ArrowStreamReader, ArrowStreamWriter}
+
+import org.apache.spark._
+import org.apache.spark.api.python._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.arrow.{ArrowUtils, ArrowWriter}
+import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
+
+/**
+ * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream.
+ */
+class ArrowPythonRunner(
+    funcs: Seq[ChainedPythonFunctions],
+    batchSize: Int,
+    bufferSize: Int,
+    reuseWorker: Boolean,
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    schema: StructType)
+  extends BasePythonRunner[InternalRow, ColumnarBatch](
+    funcs, bufferSize, reuseWorker, evalType, argOffsets) {
+
+  protected override def newWriterThread(
+      env: SparkEnv,
+      worker: Socket,
+      inputIterator: Iterator[InternalRow],
+      partitionIndex: Int,
+      context: TaskContext): WriterThread = {
+    new WriterThread(env, worker, inputIterator, partitionIndex, context) {
+
+      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
+        PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+      }
+
+      protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
+        val arrowSchema = ArrowUtils.toArrowSchema(schema)
+        val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+          s"stdout writer for $pythonExec", 0, Long.MaxValue)
+
+        val root = VectorSchemaRoot.create(arrowSchema, allocator)
+        val arrowWriter = ArrowWriter.create(root)
+
+        var closed = false
+
+        context.addTaskCompletionListener { _ =>
+          if (!closed) {
+            root.close()
+            allocator.close()
+          }
+        }
+
+        val writer = new ArrowStreamWriter(root, null, dataOut)
+        writer.start()
+
+        Utils.tryWithSafeFinally {
+          while (inputIterator.hasNext) {
+            var rowCount = 0
+            while (inputIterator.hasNext && (batchSize <= 0 || rowCount < batchSize)) {
+              val row = inputIterator.next()
+              arrowWriter.write(row)
+              rowCount += 1
+            }
+            arrowWriter.finish()
+            writer.writeBatch()
+            arrowWriter.reset()
+          }
+        } {
+          writer.end()
+          root.close()
+          allocator.close()
+          closed = true
+        }
+      }
+    }
+  }
+
+  protected override def newReaderIterator(
+      stream: DataInputStream,
+      writerThread: WriterThread,
+      startTime: Long,
+      env: SparkEnv,
+      worker: Socket,
+      released: AtomicBoolean,
+      context: TaskContext): Iterator[ColumnarBatch] = {
+    new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) {
+
+      private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+        s"stdin reader for $pythonExec", 0, Long.MaxValue)
+
+      private var reader: ArrowStreamReader = _
+      private var root: VectorSchemaRoot = _
+      private var schema: StructType = _
+      private var vectors: Array[ColumnVector] = _
+
+      private var closed = false
+
+      context.addTaskCompletionListener { _ =>
+        // todo: we need something like `reader.end()`, which release all the resources, but leave
+        // the input stream open. `reader.close()` will close the socket and we can't reuse worker.
+        // So here we simply not close the reader, which is problematic.
+        if (!closed) {
+          if (root != null) {
+            root.close()
+          }
+          allocator.close()
+        }
+      }
+
+      private var batchLoaded = true
+
+      protected override def read(): ColumnarBatch = {
+        if (writerThread.exception.isDefined) {
+          throw writerThread.exception.get
+        }
+        try {
+          if (reader != null && batchLoaded) {
+            batchLoaded = reader.loadNextBatch()
+            if (batchLoaded) {
+              val batch = new ColumnarBatch(schema, vectors, root.getRowCount)
+              batch.setNumRows(root.getRowCount)
+              batch
+            } else {
+              root.close()
+              allocator.close()
+              closed = true
+              // Reach end of stream. Call `read()` again to read control data.
+              read()
+            }
+          } else {
+            stream.readInt() match {
+              case SpecialLengths.START_ARROW_STREAM =>
+                reader = new ArrowStreamReader(stream, allocator)
+                root = reader.getVectorSchemaRoot()
+                schema = ArrowUtils.fromArrowSchema(root.getSchema())
+                vectors = root.getFieldVectors().asScala.map { vector =>
+                  new ArrowColumnVector(vector)
+                }.toArray[ColumnVector]
+                read()
+              case SpecialLengths.TIMING_DATA =>
+                handleTimingData()
+                read()
+              case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
+                throw handlePythonException()
+              case SpecialLengths.END_OF_DATA_SECTION =>
+                handleEndOfDataSection()
+                null
+            }
+          }
+        } catch handleException
+      }
+    }
+  }
+}
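
One detail worth noting in the writer thread above: input rows are grouped into Arrow batches of at most `batchSize` rows (`ArrowEvalPythonExec` passes `conf.arrowMaxRecordsPerBatch`, i.e. `spark.sql.execution.arrow.maxRecordsPerBatch`), and a non-positive `batchSize` produces a single batch per partition. A hypothetical Python sketch of the same chunk-then-stream pattern, for illustration only:

    import itertools
    import pandas as pd
    import pyarrow as pa

    def write_in_batches(rows, columns, sink, batch_size):
        """Stream row tuples as Arrow record batches of at most `batch_size` rows."""
        writer = None
        it = iter(rows)
        while True:
            chunk = list(itertools.islice(it, batch_size)) if batch_size > 0 else list(it)
            if not chunk:
                break
            batch = pa.RecordBatch.from_pandas(pd.DataFrame(chunk, columns=columns))
            if writer is None:
                # like ArrowStreamPandasSerializer.dump_stream, create the writer lazily
                # from the first batch's schema
                writer = pa.RecordBatchStreamWriter(sink, batch.schema)
            writer.write_batch(batch)
        if writer is not None:
            writer.close()

    sink = pa.BufferOutputStream()
    write_in_batches([(i, i * 0.5) for i in range(10)], ["id", "value"], sink, batch_size=3)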

http://git-wip-us.apache.org/repos/asf/spark/blob/09cbf3df/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index 2978eac..26ee25f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
 import net.razorvine.pickle.{Pickler, Unpickler}
 
 import org.apache.spark.TaskContext
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
@@ -68,7 +68,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
     }.grouped(100).map(x => pickle.dumps(x.toArray))
 
     // Output iterator for results from Python.
-    val outputIterator = new PythonRunner(
+    val outputIterator = new PythonUDFRunner(
         funcs, bufferSize, reuseWorker, PythonEvalType.SQL_BATCHED_UDF, argOffsets)
       .compute(inputIterator, context.partitionId(), context)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/09cbf3df/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
new file mode 100644
index 0000000..e28def1
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io._
+import java.net._
+import java.util.concurrent.atomic.AtomicBoolean
+
+import org.apache.spark._
+import org.apache.spark.api.python._
+
+/**
+ * A helper class to run Python UDFs in Spark.
+ */
+class PythonUDFRunner(
+    funcs: Seq[ChainedPythonFunctions],
+    bufferSize: Int,
+    reuseWorker: Boolean,
+    evalType: Int,
+    argOffsets: Array[Array[Int]])
+  extends BasePythonRunner[Array[Byte], Array[Byte]](
+    funcs, bufferSize, reuseWorker, evalType, argOffsets) {
+
+  protected override def newWriterThread(
+      env: SparkEnv,
+      worker: Socket,
+      inputIterator: Iterator[Array[Byte]],
+      partitionIndex: Int,
+      context: TaskContext): WriterThread = {
+    new WriterThread(env, worker, inputIterator, partitionIndex, context) {
+
+      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
+        PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+      }
+
+      protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
+        PythonRDD.writeIteratorToStream(inputIterator, dataOut)
+        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+      }
+    }
+  }
+
+  protected override def newReaderIterator(
+      stream: DataInputStream,
+      writerThread: WriterThread,
+      startTime: Long,
+      env: SparkEnv,
+      worker: Socket,
+      released: AtomicBoolean,
+      context: TaskContext): Iterator[Array[Byte]] = {
+    new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) {
+
+      protected override def read(): Array[Byte] = {
+        if (writerThread.exception.isDefined) {
+          throw writerThread.exception.get
+        }
+        try {
+          stream.readInt() match {
+            case length if length > 0 =>
+              val obj = new Array[Byte](length)
+              stream.readFully(obj)
+              obj
+            case 0 => Array.empty[Byte]
+            case SpecialLengths.TIMING_DATA =>
+              handleTimingData()
+              read()
+            case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
+              throw handlePythonException()
+            case SpecialLengths.END_OF_DATA_SECTION =>
+              handleEndOfDataSection()
+              null
+          }
+        } catch handleException
+      }
+    }
+  }
+}
+
+object PythonUDFRunner {
+
+  def writeUDFs(
+      dataOut: DataOutputStream,
+      funcs: Seq[ChainedPythonFunctions],
+      argOffsets: Array[Array[Int]]): Unit = {
+    dataOut.writeInt(funcs.length)
+    funcs.zip(argOffsets).foreach { case (chained, offsets) =>
+      dataOut.writeInt(offsets.length)
+      offsets.foreach { offset =>
+        dataOut.writeInt(offset)
+      }
+      dataOut.writeInt(chained.funcs.length)
+      chained.funcs.foreach { f =>
+        dataOut.writeInt(f.command.length)
+        dataOut.write(f.command)
+      }
+    }
+  }
+}
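
For completeness, the command section that `writeUDFs` emits is: the number of UDFs; then, per UDF, the number of argument offsets followed by the offsets, the number of chained functions, and each function's pickled command as a length-prefixed byte string. A hedged sketch of the matching reader, roughly what `pyspark.worker.read_udfs` does (the real code also unpickles each command):

    import struct

    def read_int(stream):
        # 4-byte big-endian signed int, the counterpart of DataOutputStream.writeInt
        return struct.unpack("!i", stream.read(4))[0]

    def read_udf_section(stream):
        """Returns [(arg_offsets, [pickled_command_bytes, ...]), ...], one entry per UDF."""
        udfs = []
        for _ in range(read_int(stream)):                    # number of UDFs
            arg_offsets = [read_int(stream) for _ in range(read_int(stream))]
            chained = [stream.read(read_int(stream))         # each chained function's command
                       for _ in range(read_int(stream))]
            udfs.append((arg_offsets, chained))
        return udfs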

