You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kyuubi.apache.org by ul...@apache.org on 2023/04/10 01:44:06 UTC

[kyuubi] branch branch-1.7 updated: [KYUUBI #4662] [ARROW] Arrow serialization should not introduce extra shuffle for outermost limit

This is an automated email from the ASF dual-hosted git repository.

ulyssesyou pushed a commit to branch branch-1.7
in repository https://gitbox.apache.org/repos/asf/kyuubi.git


The following commit(s) were added to refs/heads/branch-1.7 by this push:
     new a8fecd1e7 [KYUUBI #4662] [ARROW] Arrow serialization should not introduce extra shuffle for outermost limit
a8fecd1e7 is described below

commit a8fecd1e7cd484724db02ea759ce7c1984ed1185
Author: Fu Chen <cf...@gmail.com>
AuthorDate: Mon Apr 10 09:43:30 2023 +0800

    [KYUUBI #4662] [ARROW] Arrow serialization should not introduce extra shuffle for outermost limit
    
    ### _Why are the changes needed?_
    
    The fundamental concept is to execute a job similar to the way in which `CollectLimitExec.executeCollect()` operates.
    
    ```sql
    select * from parquet.`parquet/tpcds/sf1000/catalog_sales` limit 20;
    ```
    
    Before this PR:
    ![截屏2023-04-04 下午3 20 34](https://user-images.githubusercontent.com/8537877/229717946-87c480c6-9550-4d00-bc96-14d59d7ce9f7.png)
    
    ![截屏2023-04-04 下午3 20 54](https://user-images.githubusercontent.com/8537877/229717973-bf6da5af-74e7-422a-b9fa-8b7bebd43320.png)
    
    After this PR:
    
    ![截屏2023-04-04 下午3 17 05](https://user-images.githubusercontent.com/8537877/229718016-6218d019-b223-4deb-b596-6f0431d33d2a.png)
    
    ![截屏2023-04-04 下午3 17 16](https://user-images.githubusercontent.com/8537877/229718046-ea07cd1f-5ffc-42ba-87d5-08085feb4b16.png)
    
    ### _How was this patch tested?_
    - [ ] Add some test cases that check the changes thoroughly including negative and positive cases if possible
    
    - [ ] Add screenshots for manual tests if appropriate
    
    - [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before make a pull request
    
    Closes #4662 from cfmcgrady/arrow-collect-limit-exec-2.
    
    Closes #4662
    
    82c912ed6 [Fu Chen] close vector
    130bcb141 [Fu Chen] finally close
    facc13f78 [Fu Chen] exclude rule OptimizeLimitZero
    370083910 [Fu Chen] SparkArrowbasedOperationSuite adapt Spark-3.1.x
    6064ab961 [Fu Chen] limit = 0 test case
    6d596fcce [Fu Chen] address comment
    8280783c3 [Fu Chen] add `isStaticConfigKey` to adapt Spark-3.1.x
    22cc70fba [Fu Chen] add ut
    b72bc6fb2 [Fu Chen] add offset support to adapt Spark-3.4.x
    9ffb44fb2 [Fu Chen] make toBatchIterator private
    c83cf3f5e [Fu Chen] SparkArrowbasedOperationSuite adapt Spark-3.1.x
    573a262ed [Fu Chen] fix
    4cef20481 [Fu Chen] SparkArrowbasedOperationSuite adapt Spark-3.1.x
    d70aee36b [Fu Chen] SparkPlan.session -> SparkSession.active to adapt Spark-3.1.x
    e3bf84c03 [Fu Chen] refactor
    81886f01c [Fu Chen] address comment
    2286afc6b [Fu Chen] reflective calla AdaptiveSparkPlanExec.finalPhysicalPlan
    03d074732 [Fu Chen] address comment
    25e4f056c [Fu Chen] add docs
    885cf2c71 [Fu Chen] infer row size by schema.defaultSize
    4e7ca54df [Fu Chen] unnecessarily changes
    ee5a7567a [Fu Chen] revert unnecessarily changes
    6c5b1eb61 [Fu Chen] add ut
    4212a8967 [Fu Chen] refactor and add ut
    ed8c6928b [Fu Chen] refactor
    008867122 [Fu Chen] refine
    8593d856a [Fu Chen] driver slice last batch
    a5849430a [Fu Chen] arrow take
    
    Authored-by: Fu Chen <cf...@gmail.com>
    Signed-off-by: ulyssesyou <ul...@apache.org>
    (cherry picked from commit 1a651254cb9dec71082e9cfadd58a4dbbd976d1f)
    Signed-off-by: ulyssesyou <ul...@apache.org>
---
 externals/kyuubi-spark-sql-engine/pom.xml          |   7 +
 .../engine/spark/operation/ExecuteStatement.scala  |  32 +-
 .../execution/arrow/KyuubiArrowConverters.scala    | 321 +++++++++++++++++++++
 .../spark/sql/kyuubi/SparkDatasetHelper.scala      | 160 +++++++++-
 .../operation/SparkArrowbasedOperationSuite.scala  | 260 ++++++++++++++++-
 .../apache/spark/KyuubiSparkContextHelper.scala    |   2 +
 pom.xml                                            |   4 +-
 7 files changed, 753 insertions(+), 33 deletions(-)

diff --git a/externals/kyuubi-spark-sql-engine/pom.xml b/externals/kyuubi-spark-sql-engine/pom.xml
index 579dd9ca5..78821c720 100644
--- a/externals/kyuubi-spark-sql-engine/pom.xml
+++ b/externals/kyuubi-spark-sql-engine/pom.xml
@@ -65,6 +65,13 @@
             <scope>provided</scope>
         </dependency>
 
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-sql_${scala.binary.version}</artifactId>
+            <type>test-jar</type>
+            <scope>test</scope>
+        </dependency>
+
         <dependency>
             <groupId>org.apache.spark</groupId>
             <artifactId>spark-repl_${scala.binary.version}</artifactId>
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala
index fa517a8b1..015c4ba4a 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala
@@ -21,10 +21,8 @@ import java.util.concurrent.RejectedExecutionException
 
 import scala.collection.JavaConverters._
 
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.execution.SQLExecution
-import org.apache.spark.sql.kyuubi.SparkDatasetHelper
+import org.apache.spark.sql.kyuubi.SparkDatasetHelper._
 import org.apache.spark.sql.types._
 
 import org.apache.kyuubi.{KyuubiSQLException, Logging}
@@ -179,34 +177,15 @@ class ArrowBasedExecuteStatement(
   extends ExecuteStatement(session, statement, shouldRunAsync, queryTimeout, incrementalCollect) {
 
   override protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
-    collectAsArrow(convertComplexType(resultDF)) { rdd =>
-      rdd.toLocalIterator
-    }
+    toArrowBatchLocalIterator(convertComplexType(resultDF))
   }
 
   override protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
-    collectAsArrow(convertComplexType(resultDF)) { rdd =>
-      rdd.collect()
-    }
+    executeCollect(convertComplexType(resultDF))
   }
 
   override protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
-    // this will introduce shuffle and hurt performance
-    val limitedResult = resultDF.limit(maxRows)
-    collectAsArrow(convertComplexType(limitedResult)) { rdd =>
-      rdd.collect()
-    }
-  }
-
-  /**
-   * refer to org.apache.spark.sql.Dataset#withAction(), assign a new execution id for arrow-based
-   * operation, so that we can track the arrow-based queries on the UI tab.
-   */
-  private def collectAsArrow[T](df: DataFrame)(action: RDD[Array[Byte]] => T): T = {
-    SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) {
-      df.queryExecution.executedPlan.resetMetrics()
-      action(SparkDatasetHelper.toArrowBatchRdd(df))
-    }
+    executeCollect(convertComplexType(resultDF.limit(maxRows)))
   }
 
   override protected def isArrowBasedOperation: Boolean = true
@@ -214,7 +193,6 @@ class ArrowBasedExecuteStatement(
   override val resultFormat = "arrow"
 
   private def convertComplexType(df: DataFrame): DataFrame = {
-    SparkDatasetHelper.convertTopLevelComplexTypeToHiveString(df, timestampAsString)
+    convertTopLevelComplexTypeToHiveString(df, timestampAsString)
   }
-
 }
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala
new file mode 100644
index 000000000..dd6163ec9
--- /dev/null
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala
@@ -0,0 +1,321 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.arrow
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import java.nio.channels.Channels
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.arrow.vector._
+import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel}
+import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer}
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.CollectLimitExec
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.util.Utils
+
+object KyuubiArrowConverters extends SQLConfHelper with Logging {
+
+  type Batch = (Array[Byte], Long)
+
+  /**
+   * this method is to slice the input Arrow record batch byte array `bytes`, starting from `start`
+   * and taking `length` number of elements.
+   */
+  def slice(
+      schema: StructType,
+      timeZoneId: String,
+      bytes: Array[Byte],
+      start: Int,
+      length: Int): Array[Byte] = {
+    val in = new ByteArrayInputStream(bytes)
+    val out = new ByteArrayOutputStream(bytes.length)
+
+    var vectorSchemaRoot: VectorSchemaRoot = null
+    var slicedVectorSchemaRoot: VectorSchemaRoot = null
+
+    val sliceAllocator = ArrowUtils.rootAllocator.newChildAllocator(
+      "slice",
+      0,
+      Long.MaxValue)
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    vectorSchemaRoot = VectorSchemaRoot.create(arrowSchema, sliceAllocator)
+    try {
+      val recordBatch = MessageSerializer.deserializeRecordBatch(
+        new ReadChannel(Channels.newChannel(in)),
+        sliceAllocator)
+      val vectorLoader = new VectorLoader(vectorSchemaRoot)
+      vectorLoader.load(recordBatch)
+      recordBatch.close()
+      slicedVectorSchemaRoot = vectorSchemaRoot.slice(start, length)
+
+      val unloader = new VectorUnloader(slicedVectorSchemaRoot)
+      val writeChannel = new WriteChannel(Channels.newChannel(out))
+      val batch = unloader.getRecordBatch()
+      MessageSerializer.serialize(writeChannel, batch)
+      batch.close()
+      out.toByteArray()
+    } finally {
+      in.close()
+      out.close()
+      if (vectorSchemaRoot != null) {
+        vectorSchemaRoot.getFieldVectors.asScala.foreach(_.close())
+        vectorSchemaRoot.close()
+      }
+      if (slicedVectorSchemaRoot != null) {
+        slicedVectorSchemaRoot.getFieldVectors.asScala.foreach(_.close())
+        slicedVectorSchemaRoot.close()
+      }
+      sliceAllocator.close()
+    }
+  }
+
+  /**
+   * Forked from `org.apache.spark.sql.execution.SparkPlan#executeTake()`, the algorithm can be
+   * summarized in the following steps:
+   * 1. If the limit specified in the CollectLimitExec object is 0, the function returns an empty
+   *    array of batches.
+   * 2. Otherwise, execute the child query plan of the CollectLimitExec object to obtain an RDD of
+   *    data to collect.
+   * 3. Use an iterative approach to collect data in batches until the specified limit is reached.
+   *    In each iteration, it selects a subset of the partitions of the RDD to scan and tries to
+   *    collect data from them.
+   * 4. For each partition subset, we use the runJob method of the Spark context to execute a
+   *    closure that scans the partition data and converts it to Arrow batches.
+   * 5. Check if the collected data reaches the specified limit. If not, it selects another subset
+   *    of partitions to scan and repeats the process until the limit is reached or all partitions
+   *    have been scanned.
+   * 6. Return an array of all the collected Arrow batches.
+   *
+   * Note that:
+   * 1. The returned Arrow batches row count >= limit, if the input df has more than the `limit`
+   *    row count
+   * 2. We don't implement the `takeFromEnd` logical
+   *
+   * @return
+   */
+  def takeAsArrowBatches(
+      collectLimitExec: CollectLimitExec,
+      maxRecordsPerBatch: Long,
+      maxEstimatedBatchSize: Long,
+      timeZoneId: String): Array[Batch] = {
+    val n = collectLimitExec.limit
+    val schema = collectLimitExec.schema
+    if (n == 0) {
+      return new Array[Batch](0)
+    } else {
+      val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2)
+      // TODO: refactor and reuse the code from RDD's take()
+      val childRDD = collectLimitExec.child.execute()
+      val buf = new ArrayBuffer[Batch]
+      var bufferedRowSize = 0L
+      val totalParts = childRDD.partitions.length
+      var partsScanned = 0
+      while (bufferedRowSize < n && partsScanned < totalParts) {
+        // The number of partitions to try in this iteration. It is ok for this number to be
+        // greater than totalParts because we actually cap it at totalParts in runJob.
+        var numPartsToTry = limitInitialNumPartitions
+        if (partsScanned > 0) {
+          // If we didn't find any rows after the previous iteration, multiply by
+          // limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need
+          // to try, but overestimate it by 50%. We also cap the estimation in the end.
+          if (buf.isEmpty) {
+            numPartsToTry = partsScanned * limitScaleUpFactor
+          } else {
+            val left = n - bufferedRowSize
+            // As left > 0, numPartsToTry is always >= 1
+            numPartsToTry = Math.ceil(1.5 * left * partsScanned / bufferedRowSize).toInt
+            numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor)
+          }
+        }
+
+        val partsToScan =
+          partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
+
+        // TODO: SparkPlan.session introduced in SPARK-35798, replace with SparkPlan.session once we
+        // drop Spark-3.1.x support.
+        val sc = SparkSession.active.sparkContext
+        val res = sc.runJob(
+          childRDD,
+          (it: Iterator[InternalRow]) => {
+            val batches = toBatchIterator(
+              it,
+              schema,
+              maxRecordsPerBatch,
+              maxEstimatedBatchSize,
+              n,
+              timeZoneId)
+            batches.map(b => b -> batches.rowCountInLastBatch).toArray
+          },
+          partsToScan)
+
+        var i = 0
+        while (bufferedRowSize < n && i < res.length) {
+          var j = 0
+          val batches = res(i)
+          while (j < batches.length && n > bufferedRowSize) {
+            val batch = batches(j)
+            val (_, batchSize) = batch
+            buf += batch
+            bufferedRowSize += batchSize
+            j += 1
+          }
+          i += 1
+        }
+        partsScanned += partsToScan.size
+      }
+
+      buf.toArray
+    }
+  }
+
+  /**
+   * Spark introduced the config `spark.sql.limit.initialNumPartitions` since 3.4.0. see SPARK-40211
+   */
+  private def limitInitialNumPartitions: Int = {
+    conf.getConfString("spark.sql.limit.initialNumPartitions", "1")
+      .toInt
+  }
+
+  /**
+   * Different from [[org.apache.spark.sql.execution.arrow.ArrowConverters.toBatchIterator]],
+   * each output arrow batch contains this batch row count.
+   */
+  private def toBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Long,
+      maxEstimatedBatchSize: Long,
+      limit: Long,
+      timeZoneId: String): ArrowBatchIterator = {
+    new ArrowBatchIterator(
+      rowIter,
+      schema,
+      maxRecordsPerBatch,
+      maxEstimatedBatchSize,
+      limit,
+      timeZoneId,
+      TaskContext.get)
+  }
+
+  /**
+   * This class ArrowBatchIterator is derived from
+   * [[org.apache.spark.sql.execution.arrow.ArrowConverters.ArrowBatchWithSchemaIterator]],
+   * with two key differences:
+   *   1. there is no requirement to write the schema at the batch header
+   *   2. iteration halts when `rowCount` equals `limit`
+   */
+  private[sql] class ArrowBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Long,
+      maxEstimatedBatchSize: Long,
+      limit: Long,
+      timeZoneId: String,
+      context: TaskContext)
+    extends Iterator[Array[Byte]] {
+
+    protected val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    private val allocator =
+      ArrowUtils.rootAllocator.newChildAllocator(
+        s"to${this.getClass.getSimpleName}",
+        0,
+        Long.MaxValue)
+
+    private val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    protected val unloader = new VectorUnloader(root)
+    protected val arrowWriter = ArrowWriter.create(root)
+
+    Option(context).foreach {
+      _.addTaskCompletionListener[Unit] { _ =>
+        root.close()
+        allocator.close()
+      }
+    }
+
+    override def hasNext: Boolean = (rowIter.hasNext && rowCount < limit) || {
+      root.close()
+      allocator.close()
+      false
+    }
+
+    var rowCountInLastBatch: Long = 0
+    var rowCount: Long = 0
+
+    override def next(): Array[Byte] = {
+      val out = new ByteArrayOutputStream()
+      val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+      rowCountInLastBatch = 0
+      var estimatedBatchSize = 0L
+      Utils.tryWithSafeFinally {
+
+        // Always write the first row.
+        while (rowIter.hasNext && (
+            // For maxBatchSize and maxRecordsPerBatch, respect whatever smaller.
+            // If the size in bytes is positive (set properly), always write the first row.
+            rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0 ||
+              // If the size in bytes of rows are 0 or negative, unlimit it.
+              estimatedBatchSize <= 0 ||
+              estimatedBatchSize < maxEstimatedBatchSize ||
+              // If the size of rows are 0 or negative, unlimit it.
+              maxRecordsPerBatch <= 0 ||
+              rowCountInLastBatch < maxRecordsPerBatch ||
+              rowCount < limit)) {
+          val row = rowIter.next()
+          arrowWriter.write(row)
+          estimatedBatchSize += (row match {
+            case ur: UnsafeRow => ur.getSizeInBytes
+            // Trying to estimate the size of the current row
+            case _: InternalRow => schema.defaultSize
+          })
+          rowCountInLastBatch += 1
+          rowCount += 1
+        }
+        arrowWriter.finish()
+        val batch = unloader.getRecordBatch()
+        MessageSerializer.serialize(writeChannel, batch)
+
+        // Always write the Ipc options at the end.
+        ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)
+
+        batch.close()
+      } {
+        arrowWriter.reset()
+      }
+
+      out.toByteArray
+    }
+  }
+
+  // for testing
+  def fromBatchIterator(
+      arrowBatchIter: Iterator[Array[Byte]],
+      schema: StructType,
+      timeZoneId: String,
+      context: TaskContext): Iterator[InternalRow] = {
+    ArrowConverters.fromBatchIterator(arrowBatchIter, schema, timeZoneId, context)
+  }
+}
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
index 1a5429373..1c8d32c48 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
@@ -17,18 +17,75 @@
 
 package org.apache.spark.sql.kyuubi
 
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.{ByteUnit, JavaUtils}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
+import org.apache.spark.sql.execution.{CollectLimitExec, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.arrow.{ArrowConverters, KyuubiArrowConverters}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
+import org.apache.kyuubi.engine.spark.KyuubiSparkUtil
 import org.apache.kyuubi.engine.spark.schema.RowSet
+import org.apache.kyuubi.reflection.DynMethods
+
+object SparkDatasetHelper extends Logging {
+
+  def executeCollect(df: DataFrame): Array[Array[Byte]] = withNewExecutionId(df) {
+    executeArrowBatchCollect(df.queryExecution.executedPlan)
+  }
+
+  def executeArrowBatchCollect: SparkPlan => Array[Array[Byte]] = {
+    case adaptiveSparkPlan: AdaptiveSparkPlanExec =>
+      executeArrowBatchCollect(finalPhysicalPlan(adaptiveSparkPlan))
+    // TODO: avoid extra shuffle if `offset` > 0
+    case collectLimit: CollectLimitExec if offset(collectLimit) > 0 =>
+      logWarning("unsupported offset > 0, an extra shuffle will be introduced.")
+      toArrowBatchRdd(collectLimit).collect()
+    case collectLimit: CollectLimitExec if collectLimit.limit >= 0 =>
+      doCollectLimit(collectLimit)
+    case collectLimit: CollectLimitExec if collectLimit.limit < 0 =>
+      executeArrowBatchCollect(collectLimit.child)
+    case plan: SparkPlan =>
+      toArrowBatchRdd(plan).collect()
+  }
 
-object SparkDatasetHelper {
   def toArrowBatchRdd[T](ds: Dataset[T]): RDD[Array[Byte]] = {
     ds.toArrowBatchRdd
   }
 
+  /**
+   * Forked from [[Dataset.toArrowBatchRdd(plan: SparkPlan)]].
+   * Convert to an RDD of serialized ArrowRecordBatches.
+   */
+  def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
+    val schemaCaptured = plan.schema
+    // TODO: SparkPlan.session introduced in SPARK-35798, replace with SparkPlan.session once we
+    // drop Spark-3.1.x support.
+    val maxRecordsPerBatch = SparkSession.active.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = SparkSession.active.sessionState.conf.sessionLocalTimeZone
+    plan.execute().mapPartitionsInternal { iter =>
+      val context = TaskContext.get()
+      ArrowConverters.toBatchIterator(
+        iter,
+        schemaCaptured,
+        maxRecordsPerBatch,
+        timeZoneId,
+        context)
+    }
+  }
+
+  def toArrowBatchLocalIterator(df: DataFrame): Iterator[Array[Byte]] = {
+    withNewExecutionId(df) {
+      toArrowBatchRdd(df).toLocalIterator
+    }
+  }
+
   def convertTopLevelComplexTypeToHiveString(
       df: DataFrame,
       timestampAsString: Boolean): DataFrame = {
@@ -68,11 +125,108 @@ object SparkDatasetHelper {
    * Fork from Apache Spark-3.3.1 org.apache.spark.sql.catalyst.util.quoteIfNeeded to adapt to
    * Spark-3.1.x
    */
-  def quoteIfNeeded(part: String): String = {
+  private def quoteIfNeeded(part: String): String = {
     if (part.matches("[a-zA-Z0-9_]+") && !part.matches("\\d+")) {
       part
     } else {
       s"`${part.replace("`", "``")}`"
     }
   }
+
+  private lazy val maxBatchSize: Long = {
+    // respect spark connect config
+    KyuubiSparkUtil.globalSparkContext
+      .getConf
+      .getOption("spark.connect.grpc.arrow.maxBatchSize")
+      .orElse(Option("4m"))
+      .map(JavaUtils.byteStringAs(_, ByteUnit.MiB))
+      .get
+  }
+
+  private def doCollectLimit(collectLimit: CollectLimitExec): Array[Array[Byte]] = {
+    // TODO: SparkPlan.session introduced in SPARK-35798, replace with SparkPlan.session once we
+    // drop Spark-3.1.x support.
+    val timeZoneId = SparkSession.active.sessionState.conf.sessionLocalTimeZone
+    val maxRecordsPerBatch = SparkSession.active.sessionState.conf.arrowMaxRecordsPerBatch
+
+    val batches = KyuubiArrowConverters.takeAsArrowBatches(
+      collectLimit,
+      maxRecordsPerBatch,
+      maxBatchSize,
+      timeZoneId)
+
+    // note that the number of rows in the returned arrow batches may be >= `limit`, perform
+    // the slicing operation of result
+    val result = ArrayBuffer[Array[Byte]]()
+    var i = 0
+    var rest = collectLimit.limit
+    while (i < batches.length && rest > 0) {
+      val (batch, size) = batches(i)
+      if (size <= rest) {
+        result += batch
+        // returned ArrowRecordBatch has less than `limit` row count, safety to do conversion
+        rest -= size.toInt
+      } else { // size > rest
+        result += KyuubiArrowConverters.slice(collectLimit.schema, timeZoneId, batch, 0, rest)
+        rest = 0
+      }
+      i += 1
+    }
+    result.toArray
+  }
+
+  /**
+   * This method provides a reflection-based implementation of
+   * [[AdaptiveSparkPlanExec.finalPhysicalPlan]] that enables us to adapt to the Spark runtime
+   * without patching SPARK-41914.
+   *
+   * TODO: Once we drop support for Spark 3.1.x, we can directly call
+   * [[AdaptiveSparkPlanExec.finalPhysicalPlan]].
+   */
+  def finalPhysicalPlan(adaptiveSparkPlanExec: AdaptiveSparkPlanExec): SparkPlan = {
+    withFinalPlanUpdate(adaptiveSparkPlanExec, identity)
+  }
+
+  /**
+   * A reflection-based implementation of [[AdaptiveSparkPlanExec.withFinalPlanUpdate]].
+   */
+  private def withFinalPlanUpdate[T](
+      adaptiveSparkPlanExec: AdaptiveSparkPlanExec,
+      fun: SparkPlan => T): T = {
+    val getFinalPhysicalPlan = DynMethods.builder("getFinalPhysicalPlan")
+      .hiddenImpl(adaptiveSparkPlanExec.getClass)
+      .build()
+    val plan = getFinalPhysicalPlan.invoke[SparkPlan](adaptiveSparkPlanExec)
+    val result = fun(plan)
+    val finalPlanUpdate = DynMethods.builder("finalPlanUpdate")
+      .hiddenImpl(adaptiveSparkPlanExec.getClass)
+      .build()
+    finalPlanUpdate.invoke[Unit](adaptiveSparkPlanExec)
+    result
+  }
+
+  /**
+   * offset support was add since Spark-3.4(set SPARK-28330), to ensure backward compatibility with
+   * earlier versions of Spark, this function uses reflective calls to the "offset".
+   */
+  private def offset(collectLimitExec: CollectLimitExec): Int = {
+    Option(
+      DynMethods.builder("offset")
+        .impl(collectLimitExec.getClass)
+        .orNoop()
+        .build()
+        .invoke[Int](collectLimitExec))
+      .getOrElse(0)
+  }
+
+  /**
+   * refer to org.apache.spark.sql.Dataset#withAction(), assign a new execution id for arrow-based
+   * operation, so that we can track the arrow-based queries on the UI tab.
+   */
+  private def withNewExecutionId[T](df: DataFrame)(body: => T): T = {
+    SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) {
+      df.queryExecution.executedPlan.resetMetrics()
+      body
+    }
+  }
 }
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
index ae6237bb5..2ef29b398 100644
--- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
@@ -18,16 +18,28 @@
 package org.apache.kyuubi.engine.spark.operation
 
 import java.sql.Statement
+import java.util.{Set => JSet}
 
 import org.apache.spark.KyuubiSparkContextHelper
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
+import org.apache.spark.sql.{QueryTest, Row, SparkSession}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{CollectLimitExec, QueryExecution, SparkPlan}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
+import org.apache.spark.sql.execution.exchange.Exchange
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.kyuubi.SparkDatasetHelper
 import org.apache.spark.sql.util.QueryExecutionListener
 
+import org.apache.kyuubi.KyuubiException
 import org.apache.kyuubi.config.KyuubiConf
 import org.apache.kyuubi.engine.spark.{SparkSQLEngine, WithSparkSQLEngine}
 import org.apache.kyuubi.engine.spark.session.SparkSessionImpl
 import org.apache.kyuubi.operation.SparkDataTypeTests
+import org.apache.kyuubi.reflection.DynFields
 
 class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTypeTests {
 
@@ -138,6 +150,155 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
     assert(metrics("numOutputRows").value === 1)
   }
 
+  test("SparkDatasetHelper.executeArrowBatchCollect should return expect row count") {
+    val returnSize = Seq(
+      0, // spark optimizer guaranty the `limit != 0`, it's just for the sanity check
+      7, // less than one partition
+      10, // equal to one partition
+      13, // between one and two partitions, run two jobs
+      20, // equal to two partitions
+      29, // between two and three partitions
+      1000, // all partitions
+      1001) // more than total row count
+
+    def runAndCheck(sparkPlan: SparkPlan, expectSize: Int): Unit = {
+      val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(sparkPlan)
+      val rows = KyuubiArrowConverters.fromBatchIterator(
+        arrowBinary.iterator,
+        sparkPlan.schema,
+        "",
+        KyuubiSparkContextHelper.dummyTaskContext())
+      assert(rows.size == expectSize)
+    }
+
+    val excludedRules = Seq(
+      "org.apache.spark.sql.catalyst.optimizer.EliminateLimits",
+      "org.apache.spark.sql.catalyst.optimizer.OptimizeLimitZero",
+      "org.apache.spark.sql.execution.adaptive.AQEPropagateEmptyRelation").mkString(",")
+    withSQLConf(
+      SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> excludedRules,
+      SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> excludedRules) {
+      // aqe
+      // outermost AdaptiveSparkPlanExec
+      spark.range(1000)
+        .repartitionByRange(100, col("id"))
+        .createOrReplaceTempView("t_1")
+      spark.sql("select * from t_1")
+        .foreachPartition { p: Iterator[Row] =>
+          assert(p.length == 10)
+          ()
+        }
+      returnSize.foreach { size =>
+        val df = spark.sql(s"select * from t_1 limit $size")
+        val headPlan = df.queryExecution.executedPlan.collectLeaves().head
+        if (SPARK_ENGINE_RUNTIME_VERSION >= "3.2") {
+          assert(headPlan.isInstanceOf[AdaptiveSparkPlanExec])
+          val finalPhysicalPlan =
+            SparkDatasetHelper.finalPhysicalPlan(headPlan.asInstanceOf[AdaptiveSparkPlanExec])
+          assert(finalPhysicalPlan.isInstanceOf[CollectLimitExec])
+        }
+        if (size > 1000) {
+          runAndCheck(df.queryExecution.executedPlan, 1000)
+        } else {
+          runAndCheck(df.queryExecution.executedPlan, size)
+        }
+      }
+
+      // outermost CollectLimitExec
+      spark.range(0, 1000, 1, numPartitions = 100)
+        .createOrReplaceTempView("t_2")
+      spark.sql("select * from t_2")
+        .foreachPartition { p: Iterator[Row] =>
+          assert(p.length == 10)
+          ()
+        }
+      returnSize.foreach { size =>
+        val df = spark.sql(s"select * from t_2 limit $size")
+        val plan = df.queryExecution.executedPlan
+        assert(plan.isInstanceOf[CollectLimitExec])
+        if (size > 1000) {
+          runAndCheck(df.queryExecution.executedPlan, 1000)
+        } else {
+          runAndCheck(df.queryExecution.executedPlan, size)
+        }
+      }
+    }
+  }
+
+  test("aqe should work properly") {
+
+    val s = spark
+    import s.implicits._
+
+    spark.sparkContext.parallelize(
+      (1 to 100).map(i => TestData(i, i.toString))).toDF()
+      .createOrReplaceTempView("testData")
+    spark.sparkContext.parallelize(
+      TestData2(1, 1) ::
+        TestData2(1, 2) ::
+        TestData2(2, 1) ::
+        TestData2(2, 2) ::
+        TestData2(3, 1) ::
+        TestData2(3, 2) :: Nil,
+      2).toDF()
+      .createOrReplaceTempView("testData2")
+
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        """
+          |SELECT * FROM(
+          |  SELECT * FROM testData join testData2 ON key = a where value = '1'
+          |) LIMIT 1
+          |""".stripMargin)
+      val smj = plan.collect { case smj: SortMergeJoinExec => smj }
+      val bhj = adaptivePlan.collect { case bhj: BroadcastHashJoinExec => bhj }
+      assert(smj.size == 1)
+      assert(bhj.size == 1)
+    }
+  }
+
+  test("result offset support") {
+    assume(SPARK_ENGINE_RUNTIME_VERSION > "3.3")
+    var numStages = 0
+    val listener = new SparkListener {
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        numStages = jobStart.stageInfos.length
+      }
+    }
+    withJdbcStatement() { statement =>
+      withSparkListener(listener) {
+        withPartitionedTable("t_3") {
+          statement.executeQuery("select * from t_3 limit 10 offset 10")
+        }
+        KyuubiSparkContextHelper.waitListenerBus(spark)
+      }
+    }
+    // the extra shuffle be introduced if the `offset` > 0
+    assert(numStages == 2)
+  }
+
+  test("arrow serialization should not introduce extra shuffle for outermost limit") {
+    var numStages = 0
+    val listener = new SparkListener {
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        numStages = jobStart.stageInfos.length
+      }
+    }
+    withJdbcStatement() { statement =>
+      withSparkListener(listener) {
+        withPartitionedTable("t_3") {
+          statement.executeQuery("select * from t_3 limit 1000")
+        }
+        KyuubiSparkContextHelper.waitListenerBus(spark)
+      }
+    }
+    // Should be only one stage since there is no shuffle.
+    assert(numStages == 1)
+  }
+
   private def checkResultSetFormat(statement: Statement, expectFormat: String): Unit = {
     val query =
       s"""
@@ -177,4 +338,101 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
       .allSessions()
       .foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.unregister(listener))
   }
+
+  private def withSparkListener[T](listener: SparkListener)(body: => T): T = {
+    withAllSessions(s => s.sparkContext.addSparkListener(listener))
+    try {
+      body
+    } finally {
+      withAllSessions(s => s.sparkContext.removeSparkListener(listener))
+    }
+
+  }
+
+  private def withPartitionedTable[T](viewName: String)(body: => T): T = {
+    withAllSessions { spark =>
+      spark.range(0, 1000, 1, numPartitions = 100)
+        .createOrReplaceTempView(viewName)
+    }
+    try {
+      body
+    } finally {
+      withAllSessions { spark =>
+        spark.sql(s"DROP VIEW IF EXISTS $viewName")
+      }
+    }
+  }
+
+  private def withAllSessions(op: SparkSession => Unit): Unit = {
+    SparkSQLEngine.currentEngine.get
+      .backendService
+      .sessionManager
+      .allSessions()
+      .map(_.asInstanceOf[SparkSessionImpl].spark)
+      .foreach(op(_))
+  }
+
+  private def runAdaptiveAndVerifyResult(query: String): (SparkPlan, SparkPlan) = {
+    val dfAdaptive = spark.sql(query)
+    val planBefore = dfAdaptive.queryExecution.executedPlan
+    val result = dfAdaptive.collect()
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+      val df = spark.sql(query)
+      QueryTest.checkAnswer(df, df.collect().toSeq)
+    }
+    val planAfter = dfAdaptive.queryExecution.executedPlan
+    val adaptivePlan = planAfter.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+    val exchanges = adaptivePlan.collect {
+      case e: Exchange => e
+    }
+    assert(exchanges.isEmpty, "The final plan should not contain any Exchange node.")
+    (dfAdaptive.queryExecution.sparkPlan, adaptivePlan)
+  }
+
+  /**
+   * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL
+   * configurations.
+   */
+  protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
+    val conf = SQLConf.get
+    val (keys, values) = pairs.unzip
+    val currentValues = keys.map { key =>
+      if (conf.contains(key)) {
+        Some(conf.getConfString(key))
+      } else {
+        None
+      }
+    }
+    (keys, values).zipped.foreach { (k, v) =>
+      if (isStaticConfigKey(k)) {
+        throw new KyuubiException(s"Cannot modify the value of a static config: $k")
+      }
+      conf.setConfString(k, v)
+    }
+    try f
+    finally {
+      keys.zip(currentValues).foreach {
+        case (key, Some(value)) => conf.setConfString(key, value)
+        case (key, None) => conf.unsetConf(key)
+      }
+    }
+  }
+
+  /**
+   * This method provides a reflection-based implementation of [[SQLConf.isStaticConfigKey]] to
+   * adapt Spark-3.1.x
+   *
+   * TODO: Once we drop support for Spark 3.1.x, we can directly call
+   * [[SQLConf.isStaticConfigKey()]].
+   */
+  private def isStaticConfigKey(key: String): Boolean = {
+    val staticConfKeys = DynFields.builder()
+      .hiddenImpl(SQLConf.getClass, "staticConfKeys")
+      .build[JSet[String]](SQLConf)
+      .get()
+    staticConfKeys.contains(key)
+  }
 }
+
+case class TestData(key: Int, value: String)
+case class TestData2(a: Int, b: Int)
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala
index 8293123ea..1b662eadf 100644
--- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala
@@ -27,4 +27,6 @@ object KyuubiSparkContextHelper {
   def waitListenerBus(spark: SparkSession): Unit = {
     spark.sparkContext.listenerBus.waitUntilEmpty()
   }
+
+  def dummyTaskContext(): TaskContextImpl = TaskContext.empty()
 }
diff --git a/pom.xml b/pom.xml
index b8c217a2a..aa6b7818f 100644
--- a/pom.xml
+++ b/pom.xml
@@ -529,8 +529,8 @@
                         <artifactId>hadoop-client</artifactId>
                     </exclusion>
                     <!--
-                      The module is only used in Kyuubi Spark Extensions, so we don't care about which
-                      version of Log4j it depends on.
+                      The module is only used in Kyuubi Spark Extensions and Engine Spark SQL, so we
+                      don't care about which version of Log4j it depends on.
                      -->
                 </exclusions>
             </dependency>