Posted to commits@spark.apache.org by rx...@apache.org on 2016/10/07 20:45:04 UTC

spark git commit: [SPARK-15621][SQL] Support spilling for Python UDF

Repository: spark
Updated Branches:
  refs/heads/master 9d8ae853e -> 2badb58cd


[SPARK-15621][SQL] Support spilling for Python UDF

## What changes were proposed in this pull request?

When executing a Python UDF, we buffer the input rows in a queue, then pull them out to join with the results from the Python UDF. If the Python UDF is slow or the input rows are too wide, we could run out of memory because of the queue. Since we can't flush all the buffers (sockets) between the JVM and the Python process from the JVM side, we can't limit the number of rows in the queue; otherwise it could deadlock.
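
The buffer-and-join pattern described above can be sketched in a few lines of Python. This is only an illustration, not the Spark code: `eval_udf_with_buffer` and its parameters are hypothetical names, and the UDF is called in-process instead of over a socket. In Spark the writer and reader sides run concurrently, so the queue can grow far beyond one batch; this single-threaded stand-in only shows how each result is matched with the oldest buffered input row.

```python
from collections import deque
from itertools import islice


def eval_udf_with_buffer(rows, udf, batch_size=100):
    """Yield (input_row, udf_result) pairs in input order.

    Hypothetical helper: every row sent to the UDF is also remembered in a
    FIFO queue, and each result is joined with the oldest buffered row.
    """
    pending = deque()  # input rows still waiting for their result
    it = iter(rows)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            break
        # Writer side: buffer the rows before "sending" them.
        pending.extend(batch)
        # Stand-in for the round trip through the Python worker.
        results = [udf(row) for row in batch]
        # Reader side: results come back in order, so pop the oldest row.
        for result in results:
            yield pending.popleft(), result
```

For example, `list(eval_udf_with_buffer(range(3), lambda x: x * 2))` pairs each input with its doubled value, in order.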

This PR manages the memory used by the queue, spilling it to disk when there is not enough memory (and releasing the memory and disk space as soon as possible).
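
As a rough sketch of the idea (not the Scala implementation in this commit), the queue can be modeled as a list of segments: new rows go into an in-memory segment while a memory budget allows it, and into a length-prefixed temp file otherwise. All names here (`SpillableQueue`, `MemSegment`, `DiskSegment`, `memory_budget`, `segment_size`) are hypothetical, and a plain byte budget stands in for Spark's task memory manager. The sketch also leaves out spilling segments that are already in memory.

```python
import os
import struct
import tempfile
from collections import deque


class MemSegment:
    """In-memory segment with a fixed byte capacity (4-byte length + record)."""

    def __init__(self, capacity):
        self.capacity, self.used, self.records = capacity, 0, deque()

    def add(self, record):
        if self.used + 4 + len(record) > self.capacity:
            return False  # full: the caller must start a new segment
        self.records.append(record)
        self.used += 4 + len(record)
        return True

    def remove(self):
        return self.records.popleft() if self.records else None

    def close(self):
        self.records.clear()


class DiskSegment:
    """File-backed segment; stops accepting writes once reading has begun."""

    def __init__(self, tmp_dir=None):
        fd, self.path = tempfile.mkstemp(prefix="buffer", dir=tmp_dir)
        self.out, self.inp, self.unread = os.fdopen(fd, "wb"), None, 0

    def add(self, record):
        if self.out is None:
            return False
        self.out.write(struct.pack(">i", len(record)) + record)
        self.unread += 1
        return True

    def remove(self):
        if self.out is not None:  # first read: switch from writing to reading
            self.out.close()
            self.out, self.inp = None, open(self.path, "rb")
        if self.unread == 0:
            return None
        (size,) = struct.unpack(">i", self.inp.read(4))
        self.unread -= 1
        return self.inp.read(size)

    def close(self):
        for f in (self.out, self.inp):
            if f is not None:
                f.close()
        self.out = self.inp = None
        if os.path.exists(self.path):
            os.remove(self.path)


class SpillableQueue:
    """FIFO of byte records that falls back to disk when the budget is used up."""

    def __init__(self, memory_budget, segment_size, tmp_dir=None):
        self.memory_budget, self.segment_size = memory_budget, segment_size
        self.tmp_dir, self.memory_used = tmp_dir, 0
        self.segments = deque()  # oldest first; excludes the one being read
        self.writing = self.reading = None

    def _new_segment(self, required):
        size = max(self.segment_size, required)
        if self.memory_used + size <= self.memory_budget:
            self.memory_used += size
            segment = MemSegment(size)
        else:
            segment = DiskSegment(self.tmp_dir)  # not enough memory: use disk
        self.segments.append(segment)
        return segment

    def _release(self, segment):
        segment.close()
        if isinstance(segment, MemSegment):
            self.memory_used -= segment.capacity  # free memory as soon as possible

    def add(self, record):
        if self.writing is None or not self.writing.add(record):
            self.writing = self._new_segment(4 + len(record))
            if not self.writing.add(record):
                raise RuntimeError("failed to push a record into a new segment")

    def remove(self):
        record = self.reading.remove() if self.reading else None
        while record is None and self.segments:
            if self.reading is not None:
                self._release(self.reading)  # exhausted segment: release it now
            self.reading = self.segments.popleft()
            record = self.reading.remove()
        return record

    def close(self):
        if self.reading is not None:
            self._release(self.reading)
        while self.segments:
            self._release(self.segments.popleft())
        self.reading = self.writing = None
```

With `SpillableQueue(memory_budget=64, segment_size=32)` and 8-byte records, the first four records land in two in-memory segments and every later record goes to a temp file, while `remove()` still returns them in insertion order.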

## How was this patch tested?

Added unit tests. Also manually ran a workload with large input rows and a slow Python UDF (with a large broadcast) like this:

```
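from pyspark.sql.functions import length, lit, udf
from pyspark.sql.types import IntegerType

b = list(range(1 << 24))  # large list captured by the UDF closure
add = udf(lambda x: x + len(b), IntegerType())
df = sqlContext.range(1, 1 << 26, 1, 4)
print(df.select(df.id, lit("adf" * 10000).alias("s"), add(df.id).alias("add")).groupBy(length("s")).sum().collect())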
```

It ran out of memory (hung because of full GC) before the patch and ran smoothly after the patch.

Author: Davies Liu <da...@databricks.com>

Closes #15089 from davies/spill_udf.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2badb58c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2badb58c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2badb58c

Branch: refs/heads/master
Commit: 2badb58cdd7833465202197c4c52db5aa3d4c6e7
Parents: 9d8ae85
Author: Davies Liu <da...@databricks.com>
Authored: Fri Oct 7 13:45:00 2016 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Fri Oct 7 13:45:00 2016 -0700

----------------------------------------------------------------------
 .../execution/python/BatchEvalPythonExec.scala  |  36 ++-
 .../spark/sql/execution/python/RowQueue.scala   | 280 +++++++++++++++++++
 .../sql/execution/python/RowQueueSuite.scala    | 127 +++++++++
 3 files changed, 436 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2badb58c/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index d9bf4d3..f9d20ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -17,18 +17,21 @@
 
 package org.apache.spark.sql.execution.python
 
+import java.io.File
+
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
 import net.razorvine.pickle.{Pickler, Unpickler}
 
-import org.apache.spark.TaskContext
+import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.util.Utils
 
 
 /**
@@ -37,9 +40,25 @@ import org.apache.spark.sql.types.{DataType, StructField, StructType}
  * Python evaluation works by sending the necessary (projected) input data via a socket to an
  * external Python process, and combine the result from the Python process with the original row.
  *
- * For each row we send to Python, we also put it in a queue. For each output row from Python,
+ * For each row we send to Python, we also put it in a queue first. For each output row from Python,
  * we drain the queue to find the original input row. Note that if the Python process is way too
- * slow, this could lead to the queue growing unbounded and eventually run out of memory.
+ * slow, this could lead to the queue growing unbounded and spill into disk when run out of memory.
+ *
+ * Here is a diagram to show how this works:
+ *
+ *            Downstream (for parent)
+ *             /      \
+ *            /     socket  (output of UDF)
+ *           /         \
+ *        RowQueue    Python
+ *           \         /
+ *            \     socket  (input of UDF)
+ *             \     /
+ *          upstream (from child)
+ *
+ * The rows sent to and received from Python are packed into batches (100 rows) and serialized,
+ * there should be always some rows buffered in the socket or Python process, so the pulling from
+ * RowQueue ALWAYS happened after pushing into it.
  */
 case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
   extends SparkPlan {
@@ -70,7 +89,11 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
 
       // The queue used to buffer input rows so we can drain it to
       // combine input with output from Python.
-      val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()
+      val queue = HybridRowQueue(TaskContext.get().taskMemoryManager(),
+        new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
+      TaskContext.get().addTaskCompletionListener({ ctx =>
+        queue.close()
+      })
 
       val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
 
@@ -98,7 +121,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
       // For each row, add it to the queue.
       val inputIterator = iter.grouped(100).map { inputRows =>
         val toBePickled = inputRows.map { inputRow =>
-          queue.add(inputRow)
+          queue.add(inputRow.asInstanceOf[UnsafeRow])
           val row = projection(inputRow)
           if (needConversion) {
             EvaluatePython.toJava(row, schema)
@@ -132,7 +155,6 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
         StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
       }
       val resultProj = UnsafeProjection.create(output, output)
-
       outputIterator.flatMap { pickedResult =>
         val unpickledBatch = unpickle.loads(pickedResult)
         unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
@@ -144,7 +166,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
         } else {
           EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
         }
-        resultProj(joined(queue.poll(), row))
+        resultProj(joined(queue.remove(), row))
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/2badb58c/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
new file mode 100644
index 0000000..422a3f8
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
@@ -0,0 +1,280 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.python
+
+import java.io._
+
+import com.google.common.io.Closeables
+
+import org.apache.spark.SparkException
+import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.memory.MemoryBlock
+
+/**
+ * A RowQueue is an FIFO queue for UnsafeRow.
+ *
+ * This RowQueue is ONLY designed and used for Python UDF, which has only one writer and only one
+ * reader, the reader ALWAYS ran behind the writer. See the doc of class [[BatchEvalPythonExec]]
+ * on how it works.
+ */
+private[python] trait RowQueue {
+
+  /**
+   * Add a row to the end of it, returns true iff the row has been added to the queue.
+   */
+  def add(row: UnsafeRow): Boolean
+
+  /**
+   * Retrieve and remove the first row, returns null if it's empty.
+   *
+   * It can only be called after add is called, otherwise it will fail (NPE).
+   */
+  def remove(): UnsafeRow
+
+  /**
+   * Cleanup all the resources.
+   */
+  def close(): Unit
+}
+
+/**
+ * A RowQueue that is based on in-memory page. UnsafeRows are appended into it until it's full.
+ * Another thread could read from it at the same time (behind the writer).
+ *
+ * The format of UnsafeRow in page:
+ * [4 bytes to hold length of record (N)] [N bytes to hold record] [...]
+ *
+ * -1 length means end of page.
+ */
+private[python] abstract class InMemoryRowQueue(val page: MemoryBlock, numFields: Int)
+  extends RowQueue {
+  private val base: AnyRef = page.getBaseObject
+  private val endOfPage: Long = page.getBaseOffset + page.size
+  // the first location where a new row would be written
+  private var writeOffset = page.getBaseOffset
+  // points to the start of the next row to read
+  private var readOffset = page.getBaseOffset
+  private val resultRow = new UnsafeRow(numFields)
+
+  def add(row: UnsafeRow): Boolean = synchronized {
+    val size = row.getSizeInBytes
+    if (writeOffset + 4 + size > endOfPage) {
+      // if there is not enough space in this page to hold the new record
+      if (writeOffset + 4 <= endOfPage) {
+        // if there's extra space at the end of the page, store a special "end-of-page" length (-1)
+        Platform.putInt(base, writeOffset, -1)
+      }
+      false
+    } else {
+      Platform.putInt(base, writeOffset, size)
+      Platform.copyMemory(row.getBaseObject, row.getBaseOffset, base, writeOffset + 4, size)
+      writeOffset += 4 + size
+      true
+    }
+  }
+
+  def remove(): UnsafeRow = synchronized {
+    assert(readOffset <= writeOffset, "reader should not go beyond writer")
+    if (readOffset + 4 > endOfPage || Platform.getInt(base, readOffset) < 0) {
+      null
+    } else {
+      val size = Platform.getInt(base, readOffset)
+      resultRow.pointTo(base, readOffset + 4, size)
+      readOffset += 4 + size
+      resultRow
+    }
+  }
+}
+
+/**
+ * A RowQueue that is backed by a file on disk. This queue will stop accepting new rows once any
+ * reader has begun reading from the queue.
+ */
+private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueue {
+  private var out = new DataOutputStream(
+    new BufferedOutputStream(new FileOutputStream(file.toString)))
+  private var unreadBytes = 0L
+
+  private var in: DataInputStream = _
+  private val resultRow = new UnsafeRow(fields)
+
+  def add(row: UnsafeRow): Boolean = synchronized {
+    if (out == null) {
+      // Another thread is reading, stop writing this one
+      return false
+    }
+    out.writeInt(row.getSizeInBytes)
+    out.write(row.getBytes)
+    unreadBytes += 4 + row.getSizeInBytes
+    true
+  }
+
+  def remove(): UnsafeRow = synchronized {
+    if (out != null) {
+      out.close()
+      out = null
+      in = new DataInputStream(new BufferedInputStream(new FileInputStream(file.toString)))
+    }
+
+    if (unreadBytes > 0) {
+      val size = in.readInt()
+      val bytes = new Array[Byte](size)
+      in.readFully(bytes)
+      unreadBytes -= 4 + size
+      resultRow.pointTo(bytes, size)
+      resultRow
+    } else {
+      null
+    }
+  }
+
+  def close(): Unit = synchronized {
+    Closeables.close(out, true)
+    out = null
+    Closeables.close(in, true)
+    in = null
+    if (file.exists()) {
+      file.delete()
+    }
+  }
+}
+
+/**
+ * A RowQueue that has a list of RowQueues, which could be in memory or disk.
+ *
+ * HybridRowQueue could be safely appended in one thread, and pulled in another thread in the same
+ * time.
+ */
+private[python] case class HybridRowQueue(
+    memManager: TaskMemoryManager,
+    tempDir: File,
+    numFields: Int)
+  extends MemoryConsumer(memManager) with RowQueue {
+
+  // Each buffer should have at least one row
+  private var queues = new java.util.LinkedList[RowQueue]()
+
+  private var writing: RowQueue = _
+  private var reading: RowQueue = _
+
+  // exposed for testing
+  private[python] def numQueues(): Int = queues.size()
+
+  def spill(size: Long, trigger: MemoryConsumer): Long = {
+    if (trigger == this) {
+      // When it's triggered by itself, it should write upcoming rows into disk instead of copying
+      // the rows already in the queue.
+      return 0L
+    }
+    var released = 0L
+    synchronized {
+      // poll out all the buffers and add them back in the same order to make sure that the rows
+      // are in correct order.
+      val newQueues = new java.util.LinkedList[RowQueue]()
+      while (!queues.isEmpty) {
+        val queue = queues.remove()
+        val newQueue = if (!queues.isEmpty && queue.isInstanceOf[InMemoryRowQueue]) {
+          val diskQueue = createDiskQueue()
+          var row = queue.remove()
+          while (row != null) {
+            diskQueue.add(row)
+            row = queue.remove()
+          }
+          released += queue.asInstanceOf[InMemoryRowQueue].page.size()
+          queue.close()
+          diskQueue
+        } else {
+          queue
+        }
+        newQueues.add(newQueue)
+      }
+      queues = newQueues
+    }
+    released
+  }
+
+  private def createDiskQueue(): RowQueue = {
+    DiskRowQueue(File.createTempFile("buffer", "", tempDir), numFields)
+  }
+
+  private def createNewQueue(required: Long): RowQueue = {
+    val page = try {
+      allocatePage(required)
+    } catch {
+      case _: OutOfMemoryError =>
+        null
+    }
+    val buffer = if (page != null) {
+      new InMemoryRowQueue(page, numFields) {
+        override def close(): Unit = {
+          freePage(page)
+        }
+      }
+    } else {
+      createDiskQueue()
+    }
+
+    synchronized {
+      queues.add(buffer)
+    }
+    buffer
+  }
+
+  def add(row: UnsafeRow): Boolean = {
+    if (writing == null || !writing.add(row)) {
+      writing = createNewQueue(4 + row.getSizeInBytes)
+      if (!writing.add(row)) {
+        throw new SparkException(s"failed to push a row into $writing")
+      }
+    }
+    true
+  }
+
+  def remove(): UnsafeRow = {
+    var row: UnsafeRow = null
+    if (reading != null) {
+      row = reading.remove()
+    }
+    if (row == null) {
+      if (reading != null) {
+        reading.close()
+      }
+      synchronized {
+        reading = queues.remove()
+      }
+      assert(reading != null, s"queue should not be empty")
+      row = reading.remove()
+      assert(row != null, s"$reading should have at least one row")
+    }
+    row
+  }
+
+  def close(): Unit = {
+    if (reading != null) {
+      reading.close()
+      reading = null
+    }
+    synchronized {
+      while (!queues.isEmpty) {
+        queues.remove().close()
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/2badb58c/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
new file mode 100644
index 0000000..ffda33c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.File
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.unsafe.memory.MemoryBlock
+import org.apache.spark.util.Utils
+
+class RowQueueSuite extends SparkFunSuite {
+
+  test("in-memory queue") {
+    val page = MemoryBlock.fromLongArray(new Array[Long](1<<10))
+    val queue = new InMemoryRowQueue(page, 1) {
+      override def close() {}
+    }
+    val row = new UnsafeRow(1)
+    row.pointTo(new Array[Byte](16), 16)
+    val n = page.size() / (4 + row.getSizeInBytes)
+    var i = 0
+    while (i < n) {
+      row.setLong(0, i)
+      assert(queue.add(row), "fail to add")
+      i += 1
+    }
+    assert(!queue.add(row), "should not add more")
+    i = 0
+    while (i < n) {
+      val row = queue.remove()
+      assert(row != null, "fail to poll")
+      assert(row.getLong(0) == i, "does not match")
+      i += 1
+    }
+    assert(queue.remove() == null, "should be empty")
+    queue.close()
+  }
+
+  test("disk queue") {
+    val dir = Utils.createTempDir().getCanonicalFile
+    dir.mkdirs()
+    val queue = DiskRowQueue(new File(dir, "buffer"), 1)
+    val row = new UnsafeRow(1)
+    row.pointTo(new Array[Byte](16), 16)
+    val n = 1000
+    var i = 0
+    while (i < n) {
+      row.setLong(0, i)
+      assert(queue.add(row), "fail to add")
+      i += 1
+    }
+    val first = queue.remove()
+    assert(first != null, "first should not be null")
+    assert(first.getLong(0) == 0, "first should be 0")
+    assert(!queue.add(row), "should not add more")
+    i = 1
+    while (i < n) {
+      val row = queue.remove()
+      assert(row != null, "fail to poll")
+      assert(row.getLong(0) == i, "does not match")
+      i += 1
+    }
+    assert(queue.remove() == null, "should be empty")
+    queue.close()
+  }
+
+  test("hybrid queue") {
+    val mem = new TestMemoryManager(new SparkConf())
+    mem.limit(4<<10)
+    val taskM = new TaskMemoryManager(mem, 0)
+    val queue = HybridRowQueue(taskM, Utils.createTempDir().getCanonicalFile, 1)
+    val row = new UnsafeRow(1)
+    row.pointTo(new Array[Byte](16), 16)
+    val n = (4<<10) / 16 * 3
+    var i = 0
+    while (i < n) {
+      row.setLong(0, i)
+      assert(queue.add(row), "fail to add")
+      i += 1
+    }
+    assert(queue.numQueues() > 1, "should have more than one queue")
+    queue.spill(1<<20, null)
+    i = 0
+    while (i < n) {
+      val row = queue.remove()
+      assert(row != null, "fail to poll")
+      assert(row.getLong(0) == i, "does not match")
+      i += 1
+    }
+
+    // fill again and spill
+    i = 0
+    while (i < n) {
+      row.setLong(0, i)
+      assert(queue.add(row), "fail to add")
+      i += 1
+    }
+    assert(queue.numQueues() > 1, "should have more than one queue")
+    queue.spill(1<<20, null)
+    assert(queue.numQueues() > 1, "should have more than one queue")
+    i = 0
+    while (i < n) {
+      val row = queue.remove()
+      assert(row != null, "fail to poll")
+      assert(row.getLong(0) == i, "does not match")
+      i += 1
+    }
+    queue.close()
+  }
+}

