Posted to commits@spark.apache.org by mr...@apache.org on 2022/11/16 02:54:30 UTC
[spark] branch master updated: [SPARK-40622][SQL][CORE] Remove the limitation that single task result must fit in 2GB
This is an automated email from the ASF dual-hosted git repository.
mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c23245d78e2 [SPARK-40622][SQL][CORE] Remove the limitation that single task result must fit in 2GB
c23245d78e2 is described below
commit c23245d78e25497ac6e8848ca400a920fed62144
Author: Ziqi Liu <zi...@databricks.com>
AuthorDate: Tue Nov 15 20:54:20 2022 -0600
[SPARK-40622][SQL][CORE] Remove the limitation that single task result must fit in 2GB
### What changes were proposed in this pull request?
A single task result must currently fit in 2GB, because we hold it in a byte array or a `ByteBuffer` (which is itself backed by a byte array), so it inherits the JVM's 2GB limit on `byte[]` (arrays are indexed by `Int`).
This PR removes that limitation by replacing the byte array with a `ChunkedByteBuffer`.
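As an illustration (not code from this commit), here is a minimal sketch of the constraint and of the chunking idea: a single `byte[]` tops out near `Integer.MAX_VALUE` bytes, while a sequence of fixed-size chunks can track its total size as a `Long`.

```scala
import java.nio.ByteBuffer

object ChunkingSketch {
  def main(args: Array[String]): Unit = {
    // A JVM array is indexed by Int, so byte[] cannot exceed
    // Integer.MAX_VALUE elements (~2GB); allocating above that fails:
    // val tooBig = new Array[Byte](Int.MaxValue) // OutOfMemoryError on most VMs

    // Splitting a payload across fixed-size chunks removes that ceiling:
    // the total size is a Long, and each chunk stays far below 2GB.
    val chunkSize = 1024 * 1024 // 1MB, matching the default chunk size in this PR
    def chunked(totalSize: Long): Seq[ByteBuffer] =
      (0L until totalSize by chunkSize).map { offset =>
        ByteBuffer.allocate(math.min(chunkSize.toLong, totalSize - offset).toInt)
      }

    val chunks = chunked(5L * 1024 * 1024 + 123)
    println(s"${chunks.length} chunks, ${chunks.map(_.limit().toLong).sum} bytes total")
  }
}
```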
### Why are the changes needed?
To overcome the 2GB limit on a single task result.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit test
Closes #38064 from liuzqt/SPARK-40622.
Authored-by: Ziqi Liu <zi...@databricks.com>
Signed-off-by: Mridul <mridul<at>gmail.com>
---
.../scala/org/apache/spark/executor/Executor.scala | 19 ++++---
.../org/apache/spark/internal/config/package.scala | 2 +
.../org/apache/spark/scheduler/TaskResult.scala | 27 ++++++----
.../apache/spark/scheduler/TaskResultGetter.scala | 14 ++---
.../apache/spark/serializer/SerializerHelper.scala | 54 +++++++++++++++++++
.../main/scala/org/apache/spark/util/Utils.scala | 45 ++++++++++------
.../apache/spark/util/io/ChunkedByteBuffer.scala | 62 ++++++++++++++++++++--
.../apache/spark/io/ChunkedByteBufferSuite.scala | 50 +++++++++++++++++
.../scheduler/SchedulerIntegrationSuite.scala | 3 +-
.../spark/scheduler/TaskResultGetterSuite.scala | 2 +-
.../spark/scheduler/TaskSchedulerImplSuite.scala | 8 +--
.../spark/scheduler/TaskSetManagerSuite.scala | 2 +-
.../KryoSerializerResizableOutputSuite.scala | 16 +++---
project/SparkBuild.scala | 1 +
.../spark/sql/catalyst/expressions/Cast.scala | 6 +--
.../org/apache/spark/sql/execution/SparkPlan.scala | 22 ++++----
.../scala/org/apache/spark/sql/DatasetSuite.scala | 30 ++++++++++-
17 files changed, 289 insertions(+), 74 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index db507bd176b..8d8a4592a3e 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -48,10 +48,10 @@ import org.apache.spark.metrics.source.JVMCPUSource
import org.apache.spark.resource.ResourceInformation
import org.apache.spark.rpc.RpcTimeout
import org.apache.spark.scheduler._
+import org.apache.spark.serializer.SerializerHelper
import org.apache.spark.shuffle.{FetchFailedException, ShuffleBlockPusher}
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util._
-import org.apache.spark.util.io.ChunkedByteBuffer
/**
* Spark executor, backed by a threadpool to run tasks.
@@ -172,7 +172,7 @@ private[spark] class Executor(
env.serializerManager.setDefaultClassLoader(replClassLoader)
// Max size of direct result. If task result is bigger than this, we use the block manager
- // to send the result back.
+ // to send the result back. This is guaranteed to be smaller than the byte array limit (2GB).
private val maxDirectResultSize = Math.min(
conf.get(TASK_MAX_DIRECT_RESULT_SIZE),
RpcUtils.maxMessageSizeBytes(conf))
@@ -596,7 +596,7 @@ private[spark] class Executor(
val resultSer = env.serializer.newInstance()
val beforeSerializationNs = System.nanoTime()
- val valueBytes = resultSer.serialize(value)
+ val valueByteBuffer = SerializerHelper.serializeToChunkedBuffer(resultSer, value)
val afterSerializationNs = System.nanoTime()
// Deserialization happens in two parts: first, we deserialize a Task object, which
@@ -659,9 +659,11 @@ private[spark] class Executor(
val accumUpdates = task.collectAccumulatorUpdates()
val metricPeaks = metricsPoller.getTaskMetricPeaks(taskId)
// TODO: do not serialize value twice
- val directResult = new DirectTaskResult(valueBytes, accumUpdates, metricPeaks)
- val serializedDirectResult = ser.serialize(directResult)
- val resultSize = serializedDirectResult.limit()
+ val directResult = new DirectTaskResult(valueByteBuffer, accumUpdates, metricPeaks)
+ // try to estimate a reasonable upper bound for the serialized size of the DirectTaskResult
+ val serializedDirectResult = SerializerHelper.serializeToChunkedBuffer(ser, directResult,
+ valueByteBuffer.size + accumUpdates.size * 32 + metricPeaks.length * 8)
+ val resultSize = serializedDirectResult.size
// directSend = sending directly back to the driver
val serializedResult: ByteBuffer = {
@@ -674,13 +676,14 @@ private[spark] class Executor(
val blockId = TaskResultBlockId(taskId)
env.blockManager.putBytes(
blockId,
- new ChunkedByteBuffer(serializedDirectResult.duplicate()),
+ serializedDirectResult,
StorageLevel.MEMORY_AND_DISK_SER)
logInfo(s"Finished $taskName. $resultSize bytes result sent via BlockManager)")
ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
} else {
logInfo(s"Finished $taskName. $resultSize bytes result sent to driver")
- serializedDirectResult
+ // toByteBuffer is safe here, guarded by maxDirectResultSize
+ serializedDirectResult.toByteBuffer
}
}
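As a reading aid, the routing decision above can be paraphrased as the following sketch (invented names, not the actual Spark code; the separate `maxResultSize` check happens on the driver side):

```scala
// Simplified sketch of the executor's result routing (names are illustrative).
def routeResult(resultSize: Long, maxDirectResultSize: Long): String =
  if (resultSize > maxDirectResultSize) {
    // Too large for a direct RPC reply: store the chunked bytes in the
    // BlockManager and send a small IndirectTaskResult(blockId, size) pointer.
    "indirect via BlockManager"
  } else {
    // Small enough for the RPC message: the ChunkedByteBuffer is collapsed
    // back into one ByteBuffer, which is safe because maxDirectResultSize
    // is validated to stay below the 2GB array limit.
    "direct to driver"
  }
```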
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 64801712c5f..ad899d7dfd6 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -802,6 +802,8 @@ package object config {
ConfigBuilder("spark.task.maxDirectResultSize")
.version("2.0.0")
.bytesConf(ByteUnit.BYTE)
+ .checkValue(_ < ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toLong,
+ "The max direct result size is 2GB")
.createWithDefault(1L << 20)
private[spark] val TASK_MAX_FAILURES =
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index 11d969e1aba..e5ab74f544e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -24,20 +24,21 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkEnv
import org.apache.spark.metrics.ExecutorMetricType
-import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.serializer.{SerializerHelper, SerializerInstance}
import org.apache.spark.storage.BlockId
import org.apache.spark.util.{AccumulatorV2, Utils}
+import org.apache.spark.util.io.ChunkedByteBuffer
// Task result. Also contains updates to accumulator variables and executor metric peaks.
private[spark] sealed trait TaskResult[T]
/** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
-private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int)
+private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Long)
extends TaskResult[T] with Serializable
/** A TaskResult that contains the task's return value, accumulator updates and metric peaks. */
private[spark] class DirectTaskResult[T](
- var valueBytes: ByteBuffer,
+ var valueByteBuffer: ChunkedByteBuffer,
var accumUpdates: Seq[AccumulatorV2[_, _]],
var metricPeaks: Array[Long])
extends TaskResult[T] with Externalizable {
@@ -45,12 +46,18 @@ private[spark] class DirectTaskResult[T](
private var valueObjectDeserialized = false
private var valueObject: T = _
- def this() = this(null.asInstanceOf[ByteBuffer], null,
+ def this(
+ valueByteBuffer: ByteBuffer,
+ accumUpdates: Seq[AccumulatorV2[_, _]],
+ metricPeaks: Array[Long]) = {
+ this(new ChunkedByteBuffer(Array(valueByteBuffer)), accumUpdates, metricPeaks)
+ }
+
+ def this() = this(null.asInstanceOf[ChunkedByteBuffer], Seq(),
new Array[Long](ExecutorMetricType.numMetrics))
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
- out.writeInt(valueBytes.remaining)
- Utils.writeByteBuffer(valueBytes, out)
+ valueByteBuffer.writeExternal(out)
out.writeInt(accumUpdates.size)
accumUpdates.foreach(out.writeObject)
out.writeInt(metricPeaks.length)
@@ -58,10 +65,8 @@ private[spark] class DirectTaskResult[T](
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
- val blen = in.readInt()
- val byteVal = new Array[Byte](blen)
- in.readFully(byteVal)
- valueBytes = ByteBuffer.wrap(byteVal)
+ valueByteBuffer = new ChunkedByteBuffer()
+ valueByteBuffer.readExternal(in)
val numUpdates = in.readInt
if (numUpdates == 0) {
@@ -100,7 +105,7 @@ private[spark] class DirectTaskResult[T](
// This should not run when holding a lock because it may cost dozens of seconds for a large
// value
val ser = if (resultSer == null) SparkEnv.get.serializer.newInstance() else resultSer
- valueObject = ser.deserialize(valueBytes)
+ valueObject = SerializerHelper.deserializeFromChunkedBuffer(ser, valueByteBuffer)
valueObjectDeserialized = true
valueObject
}
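The `Externalizable` pattern used by `DirectTaskResult` above (a no-arg constructor plus `writeExternal`/`readExternal`) round-trips as in this minimal sketch, with an invented `Payload` class standing in for the real type:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, Externalizable,
  ObjectInput, ObjectInputStream, ObjectOutput, ObjectOutputStream}

// Illustrative stand-in for DirectTaskResult: a length-prefixed byte payload.
class Payload(var data: Array[Byte]) extends Externalizable {
  def this() = this(Array.empty) // required no-arg constructor
  override def writeExternal(out: ObjectOutput): Unit = {
    out.writeInt(data.length)
    out.write(data)
  }
  override def readExternal(in: ObjectInput): Unit = {
    data = new Array[Byte](in.readInt())
    in.readFully(data)
  }
}

object RoundTrip {
  def main(args: Array[String]): Unit = {
    val baos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(baos)
    new Payload(Array[Byte](1, 2, 3)).writeExternal(oos)
    oos.close()
    val restored = new Payload() // default-construct, then fill in readExternal
    restored.readExternal(new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray)))
    assert(restored.data.sameElements(Array[Byte](1, 2, 3)))
  }
}
```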
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index cfc1f79fab2..a4f29395095 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.internal.Logging
-import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.serializer.{SerializerHelper, SerializerInstance}
import org.apache.spark.util.{LongAccumulator, ThreadUtils, Utils}
/**
@@ -63,7 +63,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
try {
val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match {
case directResult: DirectTaskResult[_] =>
- if (!taskSetManager.canFetchMoreResults(directResult.valueBytes.limit())) {
+ if (!taskSetManager.canFetchMoreResults(directResult.valueByteBuffer.size)) {
// kill the task so that it will not become zombie task
scheduler.handleFailedTask(taskSetManager, tid, TaskState.KILLED, TaskKilled(
"Tasks result size has exceeded maxResultSize"))
@@ -73,7 +73,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
// We should call it here, so that when it's called again in
// "TaskSetManager.handleSuccessfulTask", it does not need to deserialize the value.
directResult.value(taskResultSerializer.get())
- (directResult, serializedData.limit())
+ (directResult, serializedData.limit().toLong)
case IndirectTaskResult(blockId, size) =>
if (!taskSetManager.canFetchMoreResults(size)) {
// dropped by executor if size is larger than maxResultSize
@@ -94,8 +94,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
taskSetManager, tid, TaskState.FINISHED, TaskResultLost)
return
}
- val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
- serializedTaskResult.get.toByteBuffer)
+ val deserializedResult = SerializerHelper
+ .deserializeFromChunkedBuffer[DirectTaskResult[_]](
+ serializer.get(),
+ serializedTaskResult.get)
// force deserialization of referenced value
deserializedResult.value(taskResultSerializer.get())
sparkEnv.blockManager.master.removeBlock(blockId)
@@ -109,7 +111,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
if (a.name == Some(InternalAccumulator.RESULT_SIZE)) {
val acc = a.asInstanceOf[LongAccumulator]
assert(acc.sum == 0L, "task result size should not have been set on the executors")
- acc.setValue(size.toLong)
+ acc.setValue(size)
acc
} else {
a
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerHelper.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerHelper.scala
new file mode 100644
index 00000000000..2cff87990a4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerHelper.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import java.nio.ByteBuffer
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
+
+private[spark] object SerializerHelper extends Logging {
+
+ /**
+ * Serialize an object into a [[ChunkedByteBuffer]].
+ * @param serializerInstance instance of SerializerInstance
+ * @param objectToSerialize the object to serialize, of type `T`
+ * @param estimatedSize estimated size of `objectToSerialize`, used as a hint for choosing the chunk size
+ */
+ def serializeToChunkedBuffer[T: ClassTag](
+ serializerInstance: SerializerInstance,
+ objectToSerialize: T,
+ estimatedSize: Long = -1): ChunkedByteBuffer = {
+ val chunkSize = ChunkedByteBuffer.estimateBufferChunkSize(estimatedSize)
+ val cbbos = new ChunkedByteBufferOutputStream(chunkSize, ByteBuffer.allocate)
+ val out = serializerInstance.serializeStream(cbbos)
+ out.writeObject(objectToSerialize)
+ out.close()
+ cbbos.close()
+ cbbos.toChunkedByteBuffer
+ }
+
+ def deserializeFromChunkedBuffer[T: ClassTag](
+ serializerInstance: SerializerInstance,
+ bytes: ChunkedByteBuffer): T = {
+ val in = serializerInstance.deserializeStream(bytes.toInputStream())
+ in.readObject()
+ }
+}
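A hypothetical usage of the new helper from within Spark's own code (it is `private[spark]`); `JavaSerializer` stands in here for whatever serializer `SparkEnv` actually provides:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.serializer.{JavaSerializer, SerializerHelper}

val ser = new JavaSerializer(new SparkConf()).newInstance()
val payload: Seq[Int] = 1 to 1000
// Passing an estimated size lets the helper pick a sensible chunk size up front.
val chunked = SerializerHelper.serializeToChunkedBuffer(ser, payload, estimatedSize = 8 * 1024)
val restored = SerializerHelper.deserializeFromChunkedBuffer[Seq[Int]](ser, chunked)
assert(restored == payload)
```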
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index f963727e79f..70477a5c9c0 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -111,6 +111,12 @@ private[spark] object Utils extends Logging {
private val PATTERN_FOR_COMMAND_LINE_ARG = "-D(.+?)=(.+)".r
+ private val COPY_BUFFER_LEN = 1024
+
+ private val copyBuffer = ThreadLocal.withInitial[Array[Byte]](() => {
+ new Array[Byte](COPY_BUFFER_LEN)
+ })
+
/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
@@ -237,34 +243,39 @@ private[spark] object Utils extends Logging {
}
}
- /**
- * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]]
- */
- def writeByteBuffer(bb: ByteBuffer, out: DataOutput): Unit = {
+ private def writeByteBufferImpl(bb: ByteBuffer, writer: (Array[Byte], Int, Int) => Unit): Unit = {
if (bb.hasArray) {
- out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
+ // Avoid an extra copy if the ByteBuffer is backed by a byte array
+ writer(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
} else {
+ // Fall back to the chunked copy approach
+ val buffer = {
+ // reuse the copy buffer from thread local
+ copyBuffer.get()
+ }
val originalPosition = bb.position()
- val bbval = new Array[Byte](bb.remaining())
- bb.get(bbval)
- out.write(bbval)
+ var bytesToCopy = Math.min(bb.remaining(), COPY_BUFFER_LEN)
+ while (bytesToCopy > 0) {
+ bb.get(buffer, 0, bytesToCopy)
+ writer(buffer, 0, bytesToCopy)
+ bytesToCopy = Math.min(bb.remaining(), COPY_BUFFER_LEN)
+ }
bb.position(originalPosition)
}
}
+ /**
+ * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]]
+ */
+ def writeByteBuffer(bb: ByteBuffer, out: DataOutput): Unit = {
+ writeByteBufferImpl(bb, out.write)
+ }
+
/**
* Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.OutputStream]]
*/
def writeByteBuffer(bb: ByteBuffer, out: OutputStream): Unit = {
- if (bb.hasArray) {
- out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
- } else {
- val originalPosition = bb.position()
- val bbval = new Array[Byte](bb.remaining())
- bb.get(bbval)
- out.write(bbval)
- bb.position(originalPosition)
- }
+ writeByteBufferImpl(bb, out.write)
}
/**
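The non-array-backed branch of `writeByteBufferImpl` above drains a direct buffer through a small reusable scratch array instead of allocating one array as large as the whole payload. A standalone sketch of that loop, under the same 1KB buffer assumption:

```scala
import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer

// Standalone sketch of the copy loop: drain a direct (non-array-backed)
// buffer through a fixed 1KB scratch array, then restore its position.
val scratch = new Array[Byte](1024)
val bb = ByteBuffer.allocateDirect(4096)
val out = new ByteArrayOutputStream()
val originalPosition = bb.position()
var n = math.min(bb.remaining(), scratch.length)
while (n > 0) {
  bb.get(scratch, 0, n)   // copy up to 1KB out of the direct buffer
  out.write(scratch, 0, n)
  n = math.min(bb.remaining(), scratch.length)
}
bb.position(originalPosition) // leave the buffer as we found it
assert(out.size() == 4096)
```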
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
index 8635f1a3d70..73e4e72cc5b 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
@@ -17,7 +17,7 @@
package org.apache.spark.util.io
-import java.io.{File, FileInputStream, InputStream}
+import java.io.{Externalizable, File, FileInputStream, InputStream, ObjectInput, ObjectOutput}
import java.nio.ByteBuffer
import java.nio.channels.WritableByteChannel
@@ -42,8 +42,9 @@ import org.apache.spark.util.Utils
* buffers may also be used elsewhere then the caller is responsible for copying
* them as needed.
*/
-private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
+private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) extends Externalizable {
require(chunks != null, "chunks must not be null")
+ require(!chunks.contains(null), "chunks must not contain null")
require(chunks.forall(_.position() == 0), "chunks' positions must be 0")
// Chunk size in bytes
@@ -54,9 +55,16 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
private[this] var disposed: Boolean = false
/**
- * This size of this buffer, in bytes.
+ * The size of this buffer, in bytes. A var is used here for serialization purposes (it needs
+ * to be set after default construction).
*/
- val size: Long = chunks.map(_.limit().asInstanceOf[Long]).sum
+ private var _size: Long = chunks.map(_.limit().asInstanceOf[Long]).sum
+
+ def size: Long = _size
+
+ def this() = {
+ this(Array.empty[ByteBuffer])
+ }
def this(byteBuffer: ByteBuffer) = {
this(Array(byteBuffer))
@@ -84,6 +92,38 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
}
}
+ /**
+ * Writes to the provided ObjectOutput with zero copy if possible.
+ */
+ override def writeExternal(out: ObjectOutput): Unit = {
+ // We want to keep the chunks layout
+ out.writeInt(chunks.length)
+ val chunksCopy = getChunks()
+ chunksCopy.foreach(buffer => out.writeInt(buffer.limit()))
+ chunksCopy.foreach(Utils.writeByteBuffer(_, out))
+ }
+
+ override def readExternal(in: ObjectInput): Unit = {
+ val chunksNum = in.readInt()
+ val indices = 0 until chunksNum
+ val chunksSize = indices.map(_ => in.readInt())
+ val chunks = new Array[ByteBuffer](chunksNum)
+
+ // We deserialize all chunks into on-heap buffers by default. If a future use case needs to
+ // preserve the on-heap/off-heap nature of the chunks, we will need to record the
+ // `isDirect` property of each chunk during serialization.
+ indices.foreach { i =>
+ val chunkSize = chunksSize(i)
+ chunks(i) = {
+ val arr = new Array[Byte](chunkSize)
+ in.readFully(arr, 0, chunkSize)
+ ByteBuffer.wrap(arr)
+ }
+ }
+ this.chunks = chunks
+ this._size = chunks.map(_.limit().toLong).sum
+ }
+
/**
* Wrap this in a custom "FileRegion" which allows us to transfer over 2 GB.
*/
@@ -171,6 +211,8 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
}
private[spark] object ChunkedByteBuffer {
+ private val CHUNK_BUFFER_SIZE: Int = 1024 * 1024
+ private val MINIMUM_CHUNK_BUFFER_SIZE: Int = 1024
def fromManagedBuffer(data: ManagedBuffer): ChunkedByteBuffer = {
data match {
@@ -207,6 +249,18 @@ private[spark] object ChunkedByteBuffer {
}
out.toChunkedByteBuffer
}
+
+ /**
+ * Try to estimate an appropriate chunk size so that it is neither too large (wasting memory)
+ * nor too small (creating too many segments).
+ */
+ def estimateBufferChunkSize(estimatedSize: Long = -1): Int = {
+ if (estimatedSize < 0) {
+ CHUNK_BUFFER_SIZE
+ } else {
+ Math.max(Math.min(estimatedSize, CHUNK_BUFFER_SIZE).toInt, MINIMUM_CHUNK_BUFFER_SIZE)
+ }
+ }
}
/**
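The heuristic's behavior follows directly from the two constants above (1MB default, 1KB floor); for example:

```scala
// Expected results, derived from CHUNK_BUFFER_SIZE = 1MB and
// MINIMUM_CHUNK_BUFFER_SIZE = 1KB in the code above:
ChunkedByteBuffer.estimateBufferChunkSize()           // no hint    -> 1048576 (1MB default)
ChunkedByteBuffer.estimateBufferChunkSize(100)        // tiny hint  -> 1024 (1KB floor)
ChunkedByteBuffer.estimateBufferChunkSize(500 * 1024) // mid hint   -> 512000 (hint itself)
ChunkedByteBuffer.estimateBufferChunkSize(1L << 31)   // >2GB hint  -> 1048576 (1MB cap)
```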
diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
index 083c5e696b7..68b181de292 100644
--- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
+++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.io
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.nio.ByteBuffer
import com.google.common.io.ByteStreams
@@ -28,6 +29,18 @@ import org.apache.spark.util.io.ChunkedByteBuffer
class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext {
+ /**
+ * Compares two ChunkedByteBuffers for equality:
+ * - the number of chunks must match
+ * - each pair of corresponding chunks must have equal content
+ */
+ def assertBufferEqual(buffer1: ChunkedByteBuffer, buffer2: ChunkedByteBuffer): Unit = {
+ assert(buffer1.chunks.length == buffer2.chunks.length)
+ assert(buffer1.chunks.zip(buffer2.chunks).forall {
+ case (chunk1, chunk2) => chunk1 == chunk2
+ })
+ }
+
test("no chunks") {
val emptyChunkedByteBuffer = new ChunkedByteBuffer(Array.empty[ByteBuffer])
assert(emptyChunkedByteBuffer.size === 0)
@@ -69,6 +82,43 @@ class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext {
}
}
+ test("Externalizable: writeExternal() and readExternal()") {
+ // intentionally generate arrays of different lengths, in order to verify that the chunk layout
+ // is preserved after ser/deser
+ val byteArrays = (1 to 15).map(i => (0 until i).map(_.toByte).toArray)
+ val chunkedByteBuffer = new ChunkedByteBuffer(byteArrays.map(ByteBuffer.wrap).toArray)
+ val baos = new ByteArrayOutputStream()
+ val objOut = new ObjectOutputStream(baos)
+ chunkedByteBuffer.writeExternal(objOut)
+ objOut.close()
+ assert(chunkedByteBuffer.chunks.forall(_.position() == 0))
+
+ val chunkedByteBuffer2 = {
+ val tmp = new ChunkedByteBuffer
+ tmp.readExternal(new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray)))
+ tmp
+ }
+ assertBufferEqual(chunkedByteBuffer, chunkedByteBuffer2)
+ }
+
+ test(
+ "Externalizable: writeExternal() and readExternal() should handle off-heap buffer properly") {
+ val chunkedByteBuffer = new ChunkedByteBuffer(
+ (0 until 10).map(_ => ByteBuffer.allocateDirect(10)).toArray)
+ val baos = new ByteArrayOutputStream()
+ val objOut = new ObjectOutputStream(baos)
+ chunkedByteBuffer.writeExternal(objOut)
+ objOut.close()
+
+ val chunkedByteBuffer2 = {
+ val tmp = new ChunkedByteBuffer
+ tmp.readExternal(new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray)))
+ tmp
+ }
+
+ assertBufferEqual(chunkedByteBuffer, chunkedByteBuffer2)
+ }
+
test("toArray()") {
val empty = ByteBuffer.wrap(Array.empty[Byte])
val bytes = ByteBuffer.wrap(Array.tabulate(8)(_.toByte))
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
index 9ed26e71256..dac675fd738 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
@@ -321,7 +321,8 @@ private[spark] abstract class MockBackend(
def taskSuccess(task: TaskDescription, result: Any): Unit = {
val ser = env.serializer.newInstance()
val resultBytes = ser.serialize(result)
- val directResult = new DirectTaskResult(resultBytes, Seq(), Array()) // no accumulator updates
+ // no accumulator updates
+ val directResult = new DirectTaskResult(resultBytes, Seq(), Array[Long]())
taskUpdate(task, TaskState.FINISHED, directResult)
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index 1583d3b96ee..1f61fab3e07 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -153,7 +153,7 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local
override def canFetchMoreResults(size: Long): Boolean = false
}
val indirectTaskResult = IndirectTaskResult(TaskResultBlockId(0), 0)
- val directTaskResult = new DirectTaskResult(ByteBuffer.allocate(0), Nil, Array())
+ val directTaskResult = new DirectTaskResult(ByteBuffer.allocate(0), Nil, Array[Long]())
val ser = sc.env.closureSerializer.newInstance()
val serializedIndirect = ser.serialize(indirectTaskResult)
val serializedDirect = ser.serialize(directTaskResult)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 4e9e9755e85..b81f85bd1d7 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -761,11 +761,13 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext
}
// End the other task of the taskset, doesn't matter whether it succeeds or fails.
val otherTask = tasks(1)
- val result = new DirectTaskResult[Int](valueSer.serialize(otherTask.taskId), Seq(), Array())
+ val result = new DirectTaskResult[Int](valueSer.serialize(otherTask.taskId), Seq(),
+ Array[Long]())
tsm.handleSuccessfulTask(otherTask.taskId, result)
} else {
tasks.foreach { task =>
- val result = new DirectTaskResult[Int](valueSer.serialize(task.taskId), Seq(), Array())
+ val result = new DirectTaskResult[Int](valueSer.serialize(task.taskId), Seq(),
+ Array[Long]())
tsm.handleSuccessfulTask(task.taskId, result)
}
}
@@ -2131,7 +2133,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext
assert(2 === taskDescriptions.length)
val ser = sc.env.serializer.newInstance()
- val directResult = new DirectTaskResult[Int](ser.serialize(1), Seq(), Array.empty)
+ val directResult = new DirectTaskResult[Int](ser.serialize(1), Seq(), Array.empty[Long])
val resultBytes = ser.serialize(directResult)
val busyTask = new Runnable {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 32a43b093ee..2dc7f0d0dfa 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -882,7 +882,7 @@ class TaskSetManagerSuite
assert(manager.runningTasks === 2)
assert(manager.isZombie === false)
- val directTaskResult = new DirectTaskResult[String](null, Seq(), Array()) {
+ val directTaskResult = new DirectTaskResult[String]() {
override def value(resultSer: SerializerInstance): String = ""
}
// Complete one copy of the task, which should result in the task set manager
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala
index 25f0b19c980..41c1131a280 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala
@@ -18,8 +18,6 @@
package org.apache.spark.serializer
import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.LocalSparkContext._
-import org.apache.spark.SparkContext
import org.apache.spark.SparkException
import org.apache.spark.internal.config._
import org.apache.spark.internal.config.Kryo._
@@ -34,9 +32,10 @@ class KryoSerializerResizableOutputSuite extends SparkFunSuite {
conf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
conf.set(KRYO_SERIALIZER_BUFFER_SIZE.key, "1m")
conf.set(KRYO_SERIALIZER_MAX_BUFFER_SIZE.key, "1m")
- withSpark(new SparkContext("local", "test", conf)) { sc =>
- intercept[SparkException](sc.parallelize(x).collect())
- }
+
+ val ser = new KryoSerializer(conf)
+ val serInstance = ser.newInstance()
+ intercept[SparkException](serInstance.serialize(x))
}
test("kryo with resizable output buffer should succeed on large array") {
@@ -44,8 +43,9 @@ class KryoSerializerResizableOutputSuite extends SparkFunSuite {
conf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
conf.set(KRYO_SERIALIZER_BUFFER_SIZE.key, "1m")
conf.set(KRYO_SERIALIZER_MAX_BUFFER_SIZE.key, "2m")
- withSpark(new SparkContext("local", "test", conf)) { sc =>
- assert(sc.parallelize(x).collect() === x)
- }
+
+ val ser = new KryoSerializer(conf)
+ val serInstance = ser.newInstance()
+ serInstance.serialize(x)
}
}
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 18667d1efea..a63f52e5430 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -1175,6 +1175,7 @@ object Unidoc {
!f.getCanonicalPath.contains("org/apache/spark/unsafe/types/CalendarInterval")))
.map(_.filterNot(_.getCanonicalPath.contains("python")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/collection")))
+ .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/io")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/kvstore")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalyst")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/connect")))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 549bc70bac7..a302298d99c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -46,9 +46,9 @@ object Cast extends QueryErrorsBase {
* As per section 6.13 "cast specification" in "Information technology — Database languages " +
* "- SQL — Part 2: Foundation (SQL/Foundation)":
* If the <cast operand> is a <value expression>, then the valid combinations of TD and SD
- * in a <cast specification> are given by the following table. “Y” indicates that the
- * combination is syntactically valid without restriction; “M” indicates that the combination
- * is valid subject to other Syntax Rules in this Sub- clause being satisfied; and “N” indicates
+ * in a <cast specification> are given by the following table. "Y" indicates that the
+ * combination is syntactically valid without restriction; "M" indicates that the combination
+ * is valid subject to other Syntax Rules in this Sub- clause being satisfied; and "N" indicates
* that the combination is not valid:
* SD TD
* EN AN C D T TS YM DT BO UDT B RT CT RW
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index a56732fdc12..4aca67a17cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -17,7 +17,8 @@
package org.apache.spark.sql.execution
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+import java.io.{DataInputStream, DataOutputStream}
+import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
@@ -38,6 +39,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.NextIterator
+import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
object SparkPlan {
/** The original [[LogicalPlan]] from which this [[SparkPlan]] is converted. */
@@ -336,13 +338,13 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
* compressed.
*/
private def getByteArrayRdd(
- n: Int = -1, takeFromEnd: Boolean = false): RDD[(Long, Array[Byte])] = {
+ n: Int = -1, takeFromEnd: Boolean = false): RDD[(Long, ChunkedByteBuffer)] = {
execute().mapPartitionsInternal { iter =>
var count = 0
val buffer = new Array[Byte](4 << 10) // 4K
val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
- val bos = new ByteArrayOutputStream()
- val out = new DataOutputStream(codec.compressedOutputStream(bos))
+ val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
+ val out = new DataOutputStream(codec.compressedOutputStream(cbbos))
if (takeFromEnd && n > 0) {
// To collect n from the last, we should anyway read everything with keeping the n.
@@ -371,19 +373,19 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
out.writeInt(-1)
out.flush()
out.close()
- Iterator((count, bos.toByteArray))
+ Iterator((count, cbbos.toChunkedByteBuffer))
}
}
/**
* Decodes the byte arrays back to UnsafeRows and put them into buffer.
*/
- private def decodeUnsafeRows(bytes: Array[Byte]): Iterator[InternalRow] = {
+ private def decodeUnsafeRows(bytes: ChunkedByteBuffer): Iterator[InternalRow] = {
val nFields = schema.length
val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
- val bis = new ByteArrayInputStream(bytes)
- val ins = new DataInputStream(codec.compressedInputStream(bis))
+ val cbbis = bytes.toInputStream()
+ val ins = new DataInputStream(codec.compressedInputStream(cbbis))
new NextIterator[InternalRow] {
private var sizeOfNextRow = ins.readInt()
@@ -503,8 +505,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
parts
}
val sc = sparkContext
- val res = sc.runJob(childRDD, (it: Iterator[(Long, Array[Byte])]) =>
- if (it.hasNext) it.next() else (0L, Array.emptyByteArray), partsToScan)
+ val res = sc.runJob(childRDD, (it: Iterator[(Long, ChunkedByteBuffer)]) =>
+ if (it.hasNext) it.next() else (0L, new ChunkedByteBuffer()), partsToScan)
var i = 0
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 8f5740e65ed..370e5ca546b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -20,13 +20,16 @@ package org.apache.spark.sql
import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.sql.{Date, Timestamp}
+import scala.util.Random
+
import org.apache.hadoop.fs.{Path, PathFilter}
import org.scalatest.Assertions._
import org.scalatest.exceptions.TestFailedException
import org.scalatest.prop.TableDrivenPropertyChecks._
-import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.{SparkConf, SparkException, TaskContext}
import org.apache.spark.TestUtils.withListener
+import org.apache.spark.internal.config.MAX_RESULT_SIZE
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, ScroogeLikeExample}
import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
@@ -2228,6 +2231,31 @@ class DatasetSuite extends QueryTest
}
}
+class DatasetLargeResultCollectingSuite extends QueryTest
+ with SharedSparkSession {
+
+ override protected def sparkConf: SparkConf = super.sparkConf.set(MAX_RESULT_SIZE.key, "4g")
+ test("collect data with single partition larger than 2GB bytes array limit") {
+ // This test requires a lot of memory and OOMs in GitHub Actions, so we skip it there.
+ // Developers should verify it in a local build.
+ assume(!sys.env.contains("GITHUB_ACTIONS"))
+ import org.apache.spark.sql.functions.udf
+
+ val genData = udf((id: Long, bytesSize: Int) => {
+ val rand = new Random(id)
+ val arr = new Array[Byte](bytesSize)
+ rand.nextBytes(arr)
+ arr
+ })
+
+ spark.udf.register("genData", genData.asNondeterministic())
+ // create data of size >2GB in a single partition, which exceeds the byte array limit;
+ // generate random bytes so the data compresses poorly
+ val df = spark.range(0, 2100, 1, 1).selectExpr("id", s"genData(id, 1000000) as data")
+ val res = df.queryExecution.executedPlan.executeCollect()
+ }
+}
+
case class Bar(a: Int)
object AssertExecutionId {