Posted to commits@spark.apache.org by mr...@apache.org on 2022/11/16 02:54:30 UTC

[spark] branch master updated: [SPARK-40622][SQL][CORE] Remove the limitation that single task result must fit in 2GB

This is an automated email from the ASF dual-hosted git repository.

mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c23245d78e2 [SPARK-40622][SQL][CORE] Remove the limitation that single task result must fit in 2GB
c23245d78e2 is described below

commit c23245d78e25497ac6e8848ca400a920fed62144
Author: Ziqi Liu <zi...@databricks.com>
AuthorDate: Tue Nov 15 20:54:20 2022 -0600

    [SPARK-40622][SQL][CORE] Remove the limitation that single task result must fit in 2GB
    
    ### What changes were proposed in this pull request?
    
    A single task result must currently fit in 2GB, because we use a byte array or a `ByteBuffer` (which is also backed by a byte array), and so the result is subject to the 2GB Java array size limit on `byte[]`.
    
    This PR fixes this by replacing the byte array with `ChunkedByteBuffer`.
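
    The core of the change is the small serializer helper added below (`SerializerHelper.scala`): instead of serializing into one contiguous buffer, the task result is streamed into a `ChunkedByteBufferOutputStream`. A minimal sketch of that path, simplified from the helper in this patch (the real version also accepts an `estimatedSize` hint used to pick the chunk size):

    ```scala
    import java.nio.ByteBuffer
    import scala.reflect.ClassTag
    import org.apache.spark.serializer.SerializerInstance
    import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

    // Serialize into fixed-size on-heap chunks (1MB here) instead of one contiguous
    // byte array, so the total serialized size is no longer capped at 2GB.
    def serializeToChunkedBuffer[T: ClassTag](
        ser: SerializerInstance,
        obj: T): ChunkedByteBuffer = {
      val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
      val out = ser.serializeStream(cbbos)
      out.writeObject(obj)
      out.close()
      cbbos.close()
      cbbos.toChunkedByteBuffer
    }

    // The driver side reads the result back through a stream over the chunks.
    def deserializeFromChunkedBuffer[T: ClassTag](
        ser: SerializerInstance,
        bytes: ChunkedByteBuffer): T = {
      ser.deserializeStream(bytes.toInputStream()).readObject()
    }
    ```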
    
    ### Why are the changes needed?
    
    To overcome the 2GB limit on a single task result.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Unit test
    
    Closes #38064 from liuzqt/SPARK-40622.
    
    Authored-by: Ziqi Liu <zi...@databricks.com>
    Signed-off-by: Mridul <mridul<at>gmail.com>
---
 .../scala/org/apache/spark/executor/Executor.scala | 19 ++++---
 .../org/apache/spark/internal/config/package.scala |  2 +
 .../org/apache/spark/scheduler/TaskResult.scala    | 27 ++++++----
 .../apache/spark/scheduler/TaskResultGetter.scala  | 14 ++---
 .../apache/spark/serializer/SerializerHelper.scala | 54 +++++++++++++++++++
 .../main/scala/org/apache/spark/util/Utils.scala   | 45 ++++++++++------
 .../apache/spark/util/io/ChunkedByteBuffer.scala   | 62 ++++++++++++++++++++--
 .../apache/spark/io/ChunkedByteBufferSuite.scala   | 50 +++++++++++++++++
 .../scheduler/SchedulerIntegrationSuite.scala      |  3 +-
 .../spark/scheduler/TaskResultGetterSuite.scala    |  2 +-
 .../spark/scheduler/TaskSchedulerImplSuite.scala   |  8 +--
 .../spark/scheduler/TaskSetManagerSuite.scala      |  2 +-
 .../KryoSerializerResizableOutputSuite.scala       | 16 +++---
 project/SparkBuild.scala                           |  1 +
 .../spark/sql/catalyst/expressions/Cast.scala      |  6 +--
 .../org/apache/spark/sql/execution/SparkPlan.scala | 22 ++++----
 .../scala/org/apache/spark/sql/DatasetSuite.scala  | 30 ++++++++++-
 17 files changed, 289 insertions(+), 74 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index db507bd176b..8d8a4592a3e 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -48,10 +48,10 @@ import org.apache.spark.metrics.source.JVMCPUSource
 import org.apache.spark.resource.ResourceInformation
 import org.apache.spark.rpc.RpcTimeout
 import org.apache.spark.scheduler._
+import org.apache.spark.serializer.SerializerHelper
 import org.apache.spark.shuffle.{FetchFailedException, ShuffleBlockPusher}
 import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
 import org.apache.spark.util._
-import org.apache.spark.util.io.ChunkedByteBuffer
 
 /**
  * Spark executor, backed by a threadpool to run tasks.
@@ -172,7 +172,7 @@ private[spark] class Executor(
   env.serializerManager.setDefaultClassLoader(replClassLoader)
 
   // Max size of direct result. If task result is bigger than this, we use the block manager
-  // to send the result back.
+  // to send the result back. This is guaranteed to be smaller than array bytes limit (2GB)
   private val maxDirectResultSize = Math.min(
     conf.get(TASK_MAX_DIRECT_RESULT_SIZE),
     RpcUtils.maxMessageSizeBytes(conf))
@@ -596,7 +596,7 @@ private[spark] class Executor(
 
         val resultSer = env.serializer.newInstance()
         val beforeSerializationNs = System.nanoTime()
-        val valueBytes = resultSer.serialize(value)
+        val valueByteBuffer = SerializerHelper.serializeToChunkedBuffer(resultSer, value)
         val afterSerializationNs = System.nanoTime()
 
         // Deserialization happens in two parts: first, we deserialize a Task object, which
@@ -659,9 +659,11 @@ private[spark] class Executor(
         val accumUpdates = task.collectAccumulatorUpdates()
         val metricPeaks = metricsPoller.getTaskMetricPeaks(taskId)
         // TODO: do not serialize value twice
-        val directResult = new DirectTaskResult(valueBytes, accumUpdates, metricPeaks)
-        val serializedDirectResult = ser.serialize(directResult)
-        val resultSize = serializedDirectResult.limit()
+        val directResult = new DirectTaskResult(valueByteBuffer, accumUpdates, metricPeaks)
+        // try to estimate a reasonable upper bound of DirectTaskResult serialization
+        val serializedDirectResult = SerializerHelper.serializeToChunkedBuffer(ser, directResult,
+          valueByteBuffer.size + accumUpdates.size * 32 + metricPeaks.length * 8)
+        val resultSize = serializedDirectResult.size
 
         // directSend = sending directly back to the driver
         val serializedResult: ByteBuffer = {
@@ -674,13 +676,14 @@ private[spark] class Executor(
             val blockId = TaskResultBlockId(taskId)
             env.blockManager.putBytes(
               blockId,
-              new ChunkedByteBuffer(serializedDirectResult.duplicate()),
+              serializedDirectResult,
               StorageLevel.MEMORY_AND_DISK_SER)
             logInfo(s"Finished $taskName. $resultSize bytes result sent via BlockManager)")
             ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
           } else {
             logInfo(s"Finished $taskName. $resultSize bytes result sent to driver")
-            serializedDirectResult
+            // toByteBuffer is safe here, guarded by maxDirectResultSize
+            serializedDirectResult.toByteBuffer
           }
         }
 
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 64801712c5f..ad899d7dfd6 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -802,6 +802,8 @@ package object config {
     ConfigBuilder("spark.task.maxDirectResultSize")
       .version("2.0.0")
       .bytesConf(ByteUnit.BYTE)
+      .checkValue(_ < ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toLong,
+        "The max direct result size is 2GB")
       .createWithDefault(1L << 20)
 
   private[spark] val TASK_MAX_FAILURES =
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index 11d969e1aba..e5ab74f544e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -24,20 +24,21 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.SparkEnv
 import org.apache.spark.metrics.ExecutorMetricType
-import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.serializer.{SerializerHelper, SerializerInstance}
 import org.apache.spark.storage.BlockId
 import org.apache.spark.util.{AccumulatorV2, Utils}
+import org.apache.spark.util.io.ChunkedByteBuffer
 
 // Task result. Also contains updates to accumulator variables and executor metric peaks.
 private[spark] sealed trait TaskResult[T]
 
 /** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
-private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int)
+private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Long)
   extends TaskResult[T] with Serializable
 
 /** A TaskResult that contains the task's return value, accumulator updates and metric peaks. */
 private[spark] class DirectTaskResult[T](
-    var valueBytes: ByteBuffer,
+    var valueByteBuffer: ChunkedByteBuffer,
     var accumUpdates: Seq[AccumulatorV2[_, _]],
     var metricPeaks: Array[Long])
   extends TaskResult[T] with Externalizable {
@@ -45,12 +46,18 @@ private[spark] class DirectTaskResult[T](
   private var valueObjectDeserialized = false
   private var valueObject: T = _
 
-  def this() = this(null.asInstanceOf[ByteBuffer], null,
+  def this(
+    valueByteBuffer: ByteBuffer,
+    accumUpdates: Seq[AccumulatorV2[_, _]],
+    metricPeaks: Array[Long]) = {
+    this(new ChunkedByteBuffer(Array(valueByteBuffer)), accumUpdates, metricPeaks)
+  }
+
+  def this() = this(null.asInstanceOf[ChunkedByteBuffer], Seq(),
     new Array[Long](ExecutorMetricType.numMetrics))
 
   override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
-    out.writeInt(valueBytes.remaining)
-    Utils.writeByteBuffer(valueBytes, out)
+    valueByteBuffer.writeExternal(out)
     out.writeInt(accumUpdates.size)
     accumUpdates.foreach(out.writeObject)
     out.writeInt(metricPeaks.length)
@@ -58,10 +65,8 @@ private[spark] class DirectTaskResult[T](
   }
 
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
-    val blen = in.readInt()
-    val byteVal = new Array[Byte](blen)
-    in.readFully(byteVal)
-    valueBytes = ByteBuffer.wrap(byteVal)
+    valueByteBuffer = new ChunkedByteBuffer()
+    valueByteBuffer.readExternal(in)
 
     val numUpdates = in.readInt
     if (numUpdates == 0) {
@@ -100,7 +105,7 @@ private[spark] class DirectTaskResult[T](
       // This should not run when holding a lock because it may cost dozens of seconds for a large
       // value
       val ser = if (resultSer == null) SparkEnv.get.serializer.newInstance() else resultSer
-      valueObject = ser.deserialize(valueBytes)
+      valueObject = SerializerHelper.deserializeFromChunkedBuffer(ser, valueByteBuffer)
       valueObjectDeserialized = true
       valueObject
     }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index cfc1f79fab2..a4f29395095 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal
 import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.internal.Logging
-import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.serializer.{SerializerHelper, SerializerInstance}
 import org.apache.spark.util.{LongAccumulator, ThreadUtils, Utils}
 
 /**
@@ -63,7 +63,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
         try {
           val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match {
             case directResult: DirectTaskResult[_] =>
-              if (!taskSetManager.canFetchMoreResults(directResult.valueBytes.limit())) {
+              if (!taskSetManager.canFetchMoreResults(directResult.valueByteBuffer.size)) {
                 // kill the task so that it will not become zombie task
                 scheduler.handleFailedTask(taskSetManager, tid, TaskState.KILLED, TaskKilled(
                   "Tasks result size has exceeded maxResultSize"))
@@ -73,7 +73,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
               // We should call it here, so that when it's called again in
               // "TaskSetManager.handleSuccessfulTask", it does not need to deserialize the value.
               directResult.value(taskResultSerializer.get())
-              (directResult, serializedData.limit())
+              (directResult, serializedData.limit().toLong)
             case IndirectTaskResult(blockId, size) =>
               if (!taskSetManager.canFetchMoreResults(size)) {
                 // dropped by executor if size is larger than maxResultSize
@@ -94,8 +94,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
                   taskSetManager, tid, TaskState.FINISHED, TaskResultLost)
                 return
               }
-              val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
-                serializedTaskResult.get.toByteBuffer)
+              val deserializedResult = SerializerHelper
+                .deserializeFromChunkedBuffer[DirectTaskResult[_]](
+                  serializer.get(),
+                  serializedTaskResult.get)
               // force deserialization of referenced value
               deserializedResult.value(taskResultSerializer.get())
               sparkEnv.blockManager.master.removeBlock(blockId)
@@ -109,7 +111,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
             if (a.name == Some(InternalAccumulator.RESULT_SIZE)) {
               val acc = a.asInstanceOf[LongAccumulator]
               assert(acc.sum == 0L, "task result size should not have been set on the executors")
-              acc.setValue(size.toLong)
+              acc.setValue(size)
               acc
             } else {
               a
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerHelper.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerHelper.scala
new file mode 100644
index 00000000000..2cff87990a4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerHelper.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import java.nio.ByteBuffer
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
+
+private[spark] object SerializerHelper extends Logging {
+
+  /**
+   *
+   * @param serializerInstance instance of SerializerInstance
+   * @param objectToSerialize the object to serialize, of type `T`
+   * @param estimatedSize estimated size of `t`, used as a hint to choose proper chunk size
+   */
+  def serializeToChunkedBuffer[T: ClassTag](
+      serializerInstance: SerializerInstance,
+      objectToSerialize: T,
+      estimatedSize: Long = -1): ChunkedByteBuffer = {
+    val chunkSize = ChunkedByteBuffer.estimateBufferChunkSize(estimatedSize)
+    val cbbos = new ChunkedByteBufferOutputStream(chunkSize, ByteBuffer.allocate)
+    val out = serializerInstance.serializeStream(cbbos)
+    out.writeObject(objectToSerialize)
+    out.close()
+    cbbos.close()
+    cbbos.toChunkedByteBuffer
+  }
+
+  def deserializeFromChunkedBuffer[T: ClassTag](
+      serializerInstance: SerializerInstance,
+      bytes: ChunkedByteBuffer): T = {
+    val in = serializerInstance.deserializeStream(bytes.toInputStream())
+    in.readObject()
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index f963727e79f..70477a5c9c0 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -111,6 +111,12 @@ private[spark] object Utils extends Logging {
 
   private val PATTERN_FOR_COMMAND_LINE_ARG = "-D(.+?)=(.+)".r
 
+  private val COPY_BUFFER_LEN = 1024
+
+  private val copyBuffer = ThreadLocal.withInitial[Array[Byte]](() => {
+    new Array[Byte](COPY_BUFFER_LEN)
+  })
+
   /** Serialize an object using Java serialization */
   def serialize[T](o: T): Array[Byte] = {
     val bos = new ByteArrayOutputStream()
@@ -237,34 +243,39 @@ private[spark] object Utils extends Logging {
     }
   }
 
-  /**
-   * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]]
-   */
-  def writeByteBuffer(bb: ByteBuffer, out: DataOutput): Unit = {
+  private def writeByteBufferImpl(bb: ByteBuffer, writer: (Array[Byte], Int, Int) => Unit): Unit = {
     if (bb.hasArray) {
-      out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
+      // Avoid extra copy if the bytebuffer is backed by bytes array
+      writer(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
     } else {
+      // Fallback to copy approach
+      val buffer = {
+        // reuse the copy buffer from thread local
+        copyBuffer.get()
+      }
       val originalPosition = bb.position()
-      val bbval = new Array[Byte](bb.remaining())
-      bb.get(bbval)
-      out.write(bbval)
+      var bytesToCopy = Math.min(bb.remaining(), COPY_BUFFER_LEN)
+      while (bytesToCopy > 0) {
+        bb.get(buffer, 0, bytesToCopy)
+        writer(buffer, 0, bytesToCopy)
+        bytesToCopy = Math.min(bb.remaining(), COPY_BUFFER_LEN)
+      }
       bb.position(originalPosition)
     }
   }
 
+  /**
+   * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]]
+   */
+  def writeByteBuffer(bb: ByteBuffer, out: DataOutput): Unit = {
+    writeByteBufferImpl(bb, out.write)
+  }
+
   /**
    * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.OutputStream]]
    */
   def writeByteBuffer(bb: ByteBuffer, out: OutputStream): Unit = {
-    if (bb.hasArray) {
-      out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
-    } else {
-      val originalPosition = bb.position()
-      val bbval = new Array[Byte](bb.remaining())
-      bb.get(bbval)
-      out.write(bbval)
-      bb.position(originalPosition)
-    }
+    writeByteBufferImpl(bb, out.write)
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
index 8635f1a3d70..73e4e72cc5b 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.util.io
 
-import java.io.{File, FileInputStream, InputStream}
+import java.io.{Externalizable, File, FileInputStream, InputStream, ObjectInput, ObjectOutput}
 import java.nio.ByteBuffer
 import java.nio.channels.WritableByteChannel
 
@@ -42,8 +42,9 @@ import org.apache.spark.util.Utils
  *               buffers may also be used elsewhere then the caller is responsible for copying
  *               them as needed.
  */
-private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
+private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) extends Externalizable {
   require(chunks != null, "chunks must not be null")
+  require(!chunks.contains(null), "chunks must not contain null")
   require(chunks.forall(_.position() == 0), "chunks' positions must be 0")
 
   // Chunk size in bytes
@@ -54,9 +55,16 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
   private[this] var disposed: Boolean = false
 
   /**
-   * This size of this buffer, in bytes.
+   * This size of this buffer, in bytes. Using var here for serialization purpose (need to set a
+   * object after default construction)
    */
-  val size: Long = chunks.map(_.limit().asInstanceOf[Long]).sum
+  private var _size: Long = chunks.map(_.limit().asInstanceOf[Long]).sum
+
+  def size: Long = _size
+
+  def this() = {
+    this(Array.empty[ByteBuffer])
+  }
 
   def this(byteBuffer: ByteBuffer) = {
     this(Array(byteBuffer))
@@ -84,6 +92,38 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
     }
   }
 
+  /**
+   * Writes to the provided ObjectOutput with zero copy if possible.
+   */
+  override def writeExternal(out: ObjectOutput): Unit = {
+    // We want to keep the chunks layout
+    out.writeInt(chunks.length)
+    val chunksCopy = getChunks()
+    chunksCopy.foreach(buffer => out.writeInt(buffer.limit()))
+    chunksCopy.foreach(Utils.writeByteBuffer(_, out))
+  }
+
+  override def readExternal(in: ObjectInput): Unit = {
+    val chunksNum = in.readInt()
+    val indices = 0 until chunksNum
+    val chunksSize = indices.map(_ => in.readInt())
+    val chunks = new Array[ByteBuffer](chunksNum)
+
+    // We deserialize all chunks into on-heap buffer by default. If we have use case in the future
+    // where we want to preserve the on-heap/off-heap nature of chunks, then we need to record the
+    // `isDirect` property of each chunk during serialization
+    indices.foreach { i =>
+      val chunkSize = chunksSize(i)
+      chunks(i) = {
+        val arr = new Array[Byte](chunkSize)
+        in.readFully(arr, 0, chunkSize)
+        ByteBuffer.wrap(arr)
+      }
+    }
+    this.chunks = chunks
+    this._size = chunks.map(_.limit().toLong).sum
+  }
+
   /**
    * Wrap this in a custom "FileRegion" which allows us to transfer over 2 GB.
    */
@@ -171,6 +211,8 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
 }
 
 private[spark] object ChunkedByteBuffer {
+  private val CHUNK_BUFFER_SIZE: Int = 1024 * 1024
+  private val MINIMUM_CHUNK_BUFFER_SIZE: Int = 1024
 
   def fromManagedBuffer(data: ManagedBuffer): ChunkedByteBuffer = {
     data match {
@@ -207,6 +249,18 @@ private[spark] object ChunkedByteBuffer {
     }
     out.toChunkedByteBuffer
   }
+
+  /**
+   * Try to estimate appropriate chunk size so that it's not too large (waste memory) or too
+   * small (too many segments)
+   */
+  def estimateBufferChunkSize(estimatedSize: Long = -1): Int = {
+    if (estimatedSize < 0) {
+      CHUNK_BUFFER_SIZE
+    } else {
+      Math.max(Math.min(estimatedSize, CHUNK_BUFFER_SIZE).toInt, MINIMUM_CHUNK_BUFFER_SIZE)
+    }
+  }
 }
 
 /**
diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
index 083c5e696b7..68b181de292 100644
--- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
+++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.io
 
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
 import java.nio.ByteBuffer
 
 import com.google.common.io.ByteStreams
@@ -28,6 +29,18 @@ import org.apache.spark.util.io.ChunkedByteBuffer
 
 class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext {
 
+  /**
+   * compare two ChunkedByteBuffer:
+   * - chunks nums equal
+   * - each chunk's content
+   */
+  def assertBufferEqual(buffer1: ChunkedByteBuffer, buffer2: ChunkedByteBuffer): Unit = {
+    assert(buffer1.chunks.length == buffer2.chunks.length)
+    assert(buffer1.chunks.zip(buffer2.chunks).forall {
+      case (chunk1, chunk2) => chunk1 == chunk2
+    })
+  }
+
   test("no chunks") {
     val emptyChunkedByteBuffer = new ChunkedByteBuffer(Array.empty[ByteBuffer])
     assert(emptyChunkedByteBuffer.size === 0)
@@ -69,6 +82,43 @@ class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext {
     }
   }
 
+  test("Externalizable: writeExternal() and readExternal()") {
+    // intentionally generate arrays of different len, in order to verify the chunks layout
+    // is preserved after ser/deser
+    val byteArrays = (1 to 15).map(i => (0 until i).map(_.toByte).toArray)
+    val chunkedByteBuffer = new ChunkedByteBuffer(byteArrays.map(ByteBuffer.wrap).toArray)
+    val baos = new ByteArrayOutputStream()
+    val objOut = new ObjectOutputStream(baos)
+    chunkedByteBuffer.writeExternal(objOut)
+    objOut.close()
+    assert(chunkedByteBuffer.chunks.forall(_.position() == 0))
+
+    val chunkedByteBuffer2 = {
+      val tmp = new ChunkedByteBuffer
+      tmp.readExternal(new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray)))
+      tmp
+    }
+    assertBufferEqual(chunkedByteBuffer, chunkedByteBuffer2)
+  }
+
+  test(
+    "Externalizable: writeExternal() and readExternal() should handle off-heap buffer properly") {
+    val chunkedByteBuffer = new ChunkedByteBuffer(
+      (0 until 10).map(_ => ByteBuffer.allocateDirect(10)).toArray)
+    val baos = new ByteArrayOutputStream()
+    val objOut = new ObjectOutputStream(baos)
+    chunkedByteBuffer.writeExternal(objOut)
+    objOut.close()
+
+    val chunkedByteBuffer2 = {
+      val tmp = new ChunkedByteBuffer
+      tmp.readExternal(new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray)))
+      tmp
+    }
+
+    assertBufferEqual(chunkedByteBuffer, chunkedByteBuffer2)
+  }
+
   test("toArray()") {
     val empty = ByteBuffer.wrap(Array.empty[Byte])
     val bytes = ByteBuffer.wrap(Array.tabulate(8)(_.toByte))
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
index 9ed26e71256..dac675fd738 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
@@ -321,7 +321,8 @@ private[spark] abstract class MockBackend(
   def taskSuccess(task: TaskDescription, result: Any): Unit = {
     val ser = env.serializer.newInstance()
     val resultBytes = ser.serialize(result)
-    val directResult = new DirectTaskResult(resultBytes, Seq(), Array()) // no accumulator updates
+    // no accumulator updates
+    val directResult = new DirectTaskResult(resultBytes, Seq(), Array[Long]())
     taskUpdate(task, TaskState.FINISHED, directResult)
   }
 
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index 1583d3b96ee..1f61fab3e07 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -153,7 +153,7 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local
       override def canFetchMoreResults(size: Long): Boolean = false
     }
     val indirectTaskResult = IndirectTaskResult(TaskResultBlockId(0), 0)
-    val directTaskResult = new DirectTaskResult(ByteBuffer.allocate(0), Nil, Array())
+    val directTaskResult = new DirectTaskResult(ByteBuffer.allocate(0), Nil, Array[Long]())
     val ser = sc.env.closureSerializer.newInstance()
     val serializedIndirect = ser.serialize(indirectTaskResult)
     val serializedDirect = ser.serialize(directTaskResult)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 4e9e9755e85..b81f85bd1d7 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -761,11 +761,13 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext
         }
         // End the other task of the taskset, doesn't matter whether it succeeds or fails.
         val otherTask = tasks(1)
-        val result = new DirectTaskResult[Int](valueSer.serialize(otherTask.taskId), Seq(), Array())
+        val result = new DirectTaskResult[Int](valueSer.serialize(otherTask.taskId), Seq(),
+          Array[Long]())
         tsm.handleSuccessfulTask(otherTask.taskId, result)
       } else {
         tasks.foreach { task =>
-          val result = new DirectTaskResult[Int](valueSer.serialize(task.taskId), Seq(), Array())
+          val result = new DirectTaskResult[Int](valueSer.serialize(task.taskId), Seq(),
+            Array[Long]())
           tsm.handleSuccessfulTask(task.taskId, result)
         }
       }
@@ -2131,7 +2133,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext
     assert(2 === taskDescriptions.length)
 
     val ser = sc.env.serializer.newInstance()
-    val directResult = new DirectTaskResult[Int](ser.serialize(1), Seq(), Array.empty)
+    val directResult = new DirectTaskResult[Int](ser.serialize(1), Seq(), Array.empty[Long])
     val resultBytes = ser.serialize(directResult)
 
     val busyTask = new Runnable {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 32a43b093ee..2dc7f0d0dfa 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -882,7 +882,7 @@ class TaskSetManagerSuite
     assert(manager.runningTasks === 2)
     assert(manager.isZombie === false)
 
-    val directTaskResult = new DirectTaskResult[String](null, Seq(), Array()) {
+    val directTaskResult = new DirectTaskResult[String]() {
       override def value(resultSer: SerializerInstance): String = ""
     }
     // Complete one copy of the task, which should result in the task set manager
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala
index 25f0b19c980..41c1131a280 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala
@@ -18,8 +18,6 @@
 package org.apache.spark.serializer
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.LocalSparkContext._
-import org.apache.spark.SparkContext
 import org.apache.spark.SparkException
 import org.apache.spark.internal.config._
 import org.apache.spark.internal.config.Kryo._
@@ -34,9 +32,10 @@ class KryoSerializerResizableOutputSuite extends SparkFunSuite {
     conf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
     conf.set(KRYO_SERIALIZER_BUFFER_SIZE.key, "1m")
     conf.set(KRYO_SERIALIZER_MAX_BUFFER_SIZE.key, "1m")
-    withSpark(new SparkContext("local", "test", conf)) { sc =>
-      intercept[SparkException](sc.parallelize(x).collect())
-    }
+
+    val ser = new KryoSerializer(conf)
+    val serInstance = ser.newInstance()
+    intercept[SparkException](serInstance.serialize(x))
   }
 
   test("kryo with resizable output buffer should succeed on large array") {
@@ -44,8 +43,9 @@ class KryoSerializerResizableOutputSuite extends SparkFunSuite {
     conf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
     conf.set(KRYO_SERIALIZER_BUFFER_SIZE.key, "1m")
     conf.set(KRYO_SERIALIZER_MAX_BUFFER_SIZE.key, "2m")
-    withSpark(new SparkContext("local", "test", conf)) { sc =>
-      assert(sc.parallelize(x).collect() === x)
-    }
+
+    val ser = new KryoSerializer(conf)
+    val serInstance = ser.newInstance()
+    serInstance.serialize(x)
   }
 }
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 18667d1efea..a63f52e5430 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -1175,6 +1175,7 @@ object Unidoc {
         !f.getCanonicalPath.contains("org/apache/spark/unsafe/types/CalendarInterval")))
       .map(_.filterNot(_.getCanonicalPath.contains("python")))
       .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/collection")))
+      .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/io")))
       .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/kvstore")))
       .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalyst")))
       .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/connect")))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 549bc70bac7..a302298d99c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -46,9 +46,9 @@ object Cast extends QueryErrorsBase {
    * As per section 6.13 "cast specification" in "Information technology — Database languages " +
    * "- SQL — Part 2: Foundation (SQL/Foundation)":
    * If the <cast operand> is a <value expression>, then the valid combinations of TD and SD
-   * in a <cast specification> are given by the following table. “Y” indicates that the
-   * combination is syntactically valid without restriction; “M” indicates that the combination
-   * is valid subject to other Syntax Rules in this Sub- clause being satisfied; and “N” indicates
+   * in a <cast specification> are given by the following table. "Y" indicates that the
+   * combination is syntactically valid without restriction; "M" indicates that the combination
+   * is valid subject to other Syntax Rules in this Sub- clause being satisfied; and "N" indicates
    * that the combination is not valid:
    * SD                   TD
    *     EN AN C D T TS YM DT BO UDT B RT CT RW
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index a56732fdc12..4aca67a17cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.execution
 
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+import java.io.{DataInputStream, DataOutputStream}
+import java.nio.ByteBuffer
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.mutable.{ArrayBuffer, ListBuffer}
@@ -38,6 +39,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.NextIterator
+import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
 
 object SparkPlan {
   /** The original [[LogicalPlan]] from which this [[SparkPlan]] is converted. */
@@ -336,13 +338,13 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
    * compressed.
    */
   private def getByteArrayRdd(
-      n: Int = -1, takeFromEnd: Boolean = false): RDD[(Long, Array[Byte])] = {
+      n: Int = -1, takeFromEnd: Boolean = false): RDD[(Long, ChunkedByteBuffer)] = {
     execute().mapPartitionsInternal { iter =>
       var count = 0
       val buffer = new Array[Byte](4 << 10)  // 4K
       val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
-      val bos = new ByteArrayOutputStream()
-      val out = new DataOutputStream(codec.compressedOutputStream(bos))
+      val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
+      val out = new DataOutputStream(codec.compressedOutputStream(cbbos))
 
       if (takeFromEnd && n > 0) {
         // To collect n from the last, we should anyway read everything with keeping the n.
@@ -371,19 +373,19 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
       out.writeInt(-1)
       out.flush()
       out.close()
-      Iterator((count, bos.toByteArray))
+      Iterator((count, cbbos.toChunkedByteBuffer))
     }
   }
 
   /**
    * Decodes the byte arrays back to UnsafeRows and put them into buffer.
    */
-  private def decodeUnsafeRows(bytes: Array[Byte]): Iterator[InternalRow] = {
+  private def decodeUnsafeRows(bytes: ChunkedByteBuffer): Iterator[InternalRow] = {
     val nFields = schema.length
 
     val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
-    val bis = new ByteArrayInputStream(bytes)
-    val ins = new DataInputStream(codec.compressedInputStream(bis))
+    val cbbis = bytes.toInputStream()
+    val ins = new DataInputStream(codec.compressedInputStream(cbbis))
 
     new NextIterator[InternalRow] {
       private var sizeOfNextRow = ins.readInt()
@@ -503,8 +505,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
         parts
       }
       val sc = sparkContext
-      val res = sc.runJob(childRDD, (it: Iterator[(Long, Array[Byte])]) =>
-        if (it.hasNext) it.next() else (0L, Array.emptyByteArray), partsToScan)
+      val res = sc.runJob(childRDD, (it: Iterator[(Long, ChunkedByteBuffer)]) =>
+        if (it.hasNext) it.next() else (0L, new ChunkedByteBuffer()), partsToScan)
 
       var i = 0
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 8f5740e65ed..370e5ca546b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -20,13 +20,16 @@ package org.apache.spark.sql
 import java.io.{Externalizable, ObjectInput, ObjectOutput}
 import java.sql.{Date, Timestamp}
 
+import scala.util.Random
+
 import org.apache.hadoop.fs.{Path, PathFilter}
 import org.scalatest.Assertions._
 import org.scalatest.exceptions.TestFailedException
 import org.scalatest.prop.TableDrivenPropertyChecks._
 
-import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.{SparkConf, SparkException, TaskContext}
 import org.apache.spark.TestUtils.withListener
+import org.apache.spark.internal.config.MAX_RESULT_SIZE
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, ScroogeLikeExample}
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
@@ -2228,6 +2231,31 @@ class DatasetSuite extends QueryTest
   }
 }
 
+class DatasetLargeResultCollectingSuite extends QueryTest
+  with SharedSparkSession {
+
+  override protected def sparkConf: SparkConf = super.sparkConf.set(MAX_RESULT_SIZE.key, "4g")
+  test("collect data with single partition larger than 2GB bytes array limit") {
+    // This test requires large memory and leads to OOM in Github Action so we skip it. Developer
+    // should verify it in local build.
+    assume(!sys.env.contains("GITHUB_ACTIONS"))
+    import org.apache.spark.sql.functions.udf
+
+    val genData = udf((id: Long, bytesSize: Int) => {
+      val rand = new Random(id)
+      val arr = new Array[Byte](bytesSize)
+      rand.nextBytes(arr)
+      arr
+    })
+
+    spark.udf.register("genData", genData.asNondeterministic())
+    // create data of size >2GB in single partition, which exceeds the byte array limit
+    // random gen to make sure it's poorly compressed
+    val df = spark.range(0, 2100, 1, 1).selectExpr("id", s"genData(id, 1000000) as data")
+    val res = df.queryExecution.executedPlan.executeCollect()
+  }
+}
+
 case class Bar(a: Int)
 
 object AssertExecutionId {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org