Posted to commits@kyuubi.apache.org by ul...@apache.org on 2023/04/10 01:44:06 UTC
[kyuubi] branch branch-1.7 updated: [KYUUBI #4662] [ARROW] Arrow serialization should not introduce extra shuffle for outermost limit
This is an automated email from the ASF dual-hosted git repository.
ulyssesyou pushed a commit to branch branch-1.7
in repository https://gitbox.apache.org/repos/asf/kyuubi.git
The following commit(s) were added to refs/heads/branch-1.7 by this push:
new a8fecd1e7 [KYUUBI #4662] [ARROW] Arrow serialization should not introduce extra shuffle for outermost limit
a8fecd1e7 is described below
commit a8fecd1e7cd484724db02ea759ce7c1984ed1185
Author: Fu Chen <cf...@gmail.com>
AuthorDate: Mon Apr 10 09:43:30 2023 +0800
[KYUUBI #4662] [ARROW] Arrow serialization should not introduce extra shuffle for outermost limit
### _Why are the changes needed?_
The fundamental concept is to execute the job the same way `CollectLimitExec.executeCollect()` does: scan partitions incrementally and stop once the limit is satisfied, instead of introducing an extra shuffle. A minimal sketch of this strategy follows the SQL example below.
```sql
select * from parquet.`parquet/tpcds/sf1000/catalog_sales` limit 20;
```
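The following is an illustrative-only sketch in plain Scala (no Spark dependency; the object name `IncrementalTakeSketch` and its parameters are made up for illustration) of that `executeTake()`-style strategy: scan a growing subset of partitions per round and stop as soon as enough rows are buffered, so no single-partition shuffle is required.

```scala
// Hypothetical, self-contained illustration: mimic the CollectLimitExec.executeCollect()/
// SparkPlan.executeTake() strategy of scanning a growing subset of partitions
// instead of shuffling everything to a single partition first.
object IncrementalTakeSketch {

  def take[T](partitions: Seq[Seq[T]], limit: Int, scaleUpFactor: Int = 4): Seq[T] = {
    val buf = scala.collection.mutable.ArrayBuffer.empty[T]
    var partsScanned = 0
    var numPartsToTry = 1 // start small, in the spirit of spark.sql.limit.initialNumPartitions
    while (buf.size < limit && partsScanned < partitions.length) {
      val partsToScan = partitions.slice(partsScanned, partsScanned + numPartsToTry)
      // stand-in for sc.runJob over the selected partitions
      partsToScan.foreach(p => buf ++= p.take(limit - buf.size))
      partsScanned += partsToScan.size
      // not enough rows yet: try more partitions in the next round
      numPartsToTry *= scaleUpFactor
    }
    buf.toSeq
  }

  def main(args: Array[String]): Unit = {
    val parts = (0 until 100).map(i => Seq.tabulate(10)(j => i * 10 + j))
    // only a handful of partitions are scanned to satisfy the limit
    println(take(parts, limit = 20).size) // prints 20
  }
}
```

In this patch the real, Arrow-producing version of this loop lives in `KyuubiArrowConverters.takeAsArrowBatches` (see the diff below), which collects serialized Arrow batches per partition and lets the driver slice the last batch to the exact limit.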
Before this PR:
![Screenshot 2023-04-04 at 3.20.34 PM](https://user-images.githubusercontent.com/8537877/229717946-87c480c6-9550-4d00-bc96-14d59d7ce9f7.png)
![Screenshot 2023-04-04 at 3.20.54 PM](https://user-images.githubusercontent.com/8537877/229717973-bf6da5af-74e7-422a-b9fa-8b7bebd43320.png)
After this PR:
![Screenshot 2023-04-04 at 3.17.05 PM](https://user-images.githubusercontent.com/8537877/229718016-6218d019-b223-4deb-b596-6f0431d33d2a.png)
![Screenshot 2023-04-04 at 3.17.16 PM](https://user-images.githubusercontent.com/8537877/229718046-ea07cd1f-5ffc-42ba-87d5-08085feb4b16.png)
### _How was this patch tested?_
- [ ] Add some test cases that check the changes thoroughly including negative and positive cases if possible
- [ ] Add screenshots for manual tests if appropriate
- [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before making a pull request
Closes #4662 from cfmcgrady/arrow-collect-limit-exec-2.
Closes #4662
82c912ed6 [Fu Chen] close vector
130bcb141 [Fu Chen] finally close
facc13f78 [Fu Chen] exclude rule OptimizeLimitZero
370083910 [Fu Chen] SparkArrowbasedOperationSuite adapt Spark-3.1.x
6064ab961 [Fu Chen] limit = 0 test case
6d596fcce [Fu Chen] address comment
8280783c3 [Fu Chen] add `isStaticConfigKey` to adapt Spark-3.1.x
22cc70fba [Fu Chen] add ut
b72bc6fb2 [Fu Chen] add offset support to adapt Spark-3.4.x
9ffb44fb2 [Fu Chen] make toBatchIterator private
c83cf3f5e [Fu Chen] SparkArrowbasedOperationSuite adapt Spark-3.1.x
573a262ed [Fu Chen] fix
4cef20481 [Fu Chen] SparkArrowbasedOperationSuite adapt Spark-3.1.x
d70aee36b [Fu Chen] SparkPlan.session -> SparkSession.active to adapt Spark-3.1.x
e3bf84c03 [Fu Chen] refactor
81886f01c [Fu Chen] address comment
2286afc6b [Fu Chen] reflectively call AdaptiveSparkPlanExec.finalPhysicalPlan
03d074732 [Fu Chen] address comment
25e4f056c [Fu Chen] add docs
885cf2c71 [Fu Chen] infer row size by schema.defaultSize
4e7ca54df [Fu Chen] unnecessary changes
ee5a7567a [Fu Chen] revert unnecessary changes
6c5b1eb61 [Fu Chen] add ut
4212a8967 [Fu Chen] refactor and add ut
ed8c6928b [Fu Chen] refactor
008867122 [Fu Chen] refine
8593d856a [Fu Chen] driver slice last batch
a5849430a [Fu Chen] arrow take
Authored-by: Fu Chen <cf...@gmail.com>
Signed-off-by: ulyssesyou <ul...@apache.org>
(cherry picked from commit 1a651254cb9dec71082e9cfadd58a4dbbd976d1f)
Signed-off-by: ulyssesyou <ul...@apache.org>
---
externals/kyuubi-spark-sql-engine/pom.xml | 7 +
.../engine/spark/operation/ExecuteStatement.scala | 32 +-
.../execution/arrow/KyuubiArrowConverters.scala | 321 +++++++++++++++++++++
.../spark/sql/kyuubi/SparkDatasetHelper.scala | 160 +++++++++-
.../operation/SparkArrowbasedOperationSuite.scala | 260 ++++++++++++++++-
.../apache/spark/KyuubiSparkContextHelper.scala | 2 +
pom.xml | 4 +-
7 files changed, 753 insertions(+), 33 deletions(-)
diff --git a/externals/kyuubi-spark-sql-engine/pom.xml b/externals/kyuubi-spark-sql-engine/pom.xml
index 579dd9ca5..78821c720 100644
--- a/externals/kyuubi-spark-sql-engine/pom.xml
+++ b/externals/kyuubi-spark-sql-engine/pom.xml
@@ -65,6 +65,13 @@
<scope>provided</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_${scala.binary.version}</artifactId>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-repl_${scala.binary.version}</artifactId>
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala
index fa517a8b1..015c4ba4a 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala
@@ -21,10 +21,8 @@ import java.util.concurrent.RejectedExecutionException
import scala.collection.JavaConverters._
-import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.execution.SQLExecution
-import org.apache.spark.sql.kyuubi.SparkDatasetHelper
+import org.apache.spark.sql.kyuubi.SparkDatasetHelper._
import org.apache.spark.sql.types._
import org.apache.kyuubi.{KyuubiSQLException, Logging}
@@ -179,34 +177,15 @@ class ArrowBasedExecuteStatement(
extends ExecuteStatement(session, statement, shouldRunAsync, queryTimeout, incrementalCollect) {
override protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
- collectAsArrow(convertComplexType(resultDF)) { rdd =>
- rdd.toLocalIterator
- }
+ toArrowBatchLocalIterator(convertComplexType(resultDF))
}
override protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
- collectAsArrow(convertComplexType(resultDF)) { rdd =>
- rdd.collect()
- }
+ executeCollect(convertComplexType(resultDF))
}
override protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
- // this will introduce shuffle and hurt performance
- val limitedResult = resultDF.limit(maxRows)
- collectAsArrow(convertComplexType(limitedResult)) { rdd =>
- rdd.collect()
- }
- }
-
- /**
- * refer to org.apache.spark.sql.Dataset#withAction(), assign a new execution id for arrow-based
- * operation, so that we can track the arrow-based queries on the UI tab.
- */
- private def collectAsArrow[T](df: DataFrame)(action: RDD[Array[Byte]] => T): T = {
- SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) {
- df.queryExecution.executedPlan.resetMetrics()
- action(SparkDatasetHelper.toArrowBatchRdd(df))
- }
+ executeCollect(convertComplexType(resultDF.limit(maxRows)))
}
override protected def isArrowBasedOperation: Boolean = true
@@ -214,7 +193,6 @@ class ArrowBasedExecuteStatement(
override val resultFormat = "arrow"
private def convertComplexType(df: DataFrame): DataFrame = {
- SparkDatasetHelper.convertTopLevelComplexTypeToHiveString(df, timestampAsString)
+ convertTopLevelComplexTypeToHiveString(df, timestampAsString)
}
-
}
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala
new file mode 100644
index 000000000..dd6163ec9
--- /dev/null
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala
@@ -0,0 +1,321 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.arrow
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import java.nio.channels.Channels
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.arrow.vector._
+import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel}
+import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer}
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.CollectLimitExec
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.util.Utils
+
+object KyuubiArrowConverters extends SQLConfHelper with Logging {
+
+ type Batch = (Array[Byte], Long)
+
+ /**
+ * Slices the input Arrow record batch byte array `bytes`, starting at `start` and taking
+ * `length` rows.
+ */
+ def slice(
+ schema: StructType,
+ timeZoneId: String,
+ bytes: Array[Byte],
+ start: Int,
+ length: Int): Array[Byte] = {
+ val in = new ByteArrayInputStream(bytes)
+ val out = new ByteArrayOutputStream(bytes.length)
+
+ var vectorSchemaRoot: VectorSchemaRoot = null
+ var slicedVectorSchemaRoot: VectorSchemaRoot = null
+
+ val sliceAllocator = ArrowUtils.rootAllocator.newChildAllocator(
+ "slice",
+ 0,
+ Long.MaxValue)
+ val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+ vectorSchemaRoot = VectorSchemaRoot.create(arrowSchema, sliceAllocator)
+ try {
+ val recordBatch = MessageSerializer.deserializeRecordBatch(
+ new ReadChannel(Channels.newChannel(in)),
+ sliceAllocator)
+ val vectorLoader = new VectorLoader(vectorSchemaRoot)
+ vectorLoader.load(recordBatch)
+ recordBatch.close()
+ slicedVectorSchemaRoot = vectorSchemaRoot.slice(start, length)
+
+ val unloader = new VectorUnloader(slicedVectorSchemaRoot)
+ val writeChannel = new WriteChannel(Channels.newChannel(out))
+ val batch = unloader.getRecordBatch()
+ MessageSerializer.serialize(writeChannel, batch)
+ batch.close()
+ out.toByteArray()
+ } finally {
+ in.close()
+ out.close()
+ if (vectorSchemaRoot != null) {
+ vectorSchemaRoot.getFieldVectors.asScala.foreach(_.close())
+ vectorSchemaRoot.close()
+ }
+ if (slicedVectorSchemaRoot != null) {
+ slicedVectorSchemaRoot.getFieldVectors.asScala.foreach(_.close())
+ slicedVectorSchemaRoot.close()
+ }
+ sliceAllocator.close()
+ }
+ }
+
+ /**
+ * Forked from `org.apache.spark.sql.execution.SparkPlan#executeTake()`, the algorithm can be
+ * summarized in the following steps:
+ * 1. If the limit specified in the CollectLimitExec object is 0, the function returns an empty
+ * array of batches.
+ * 2. Otherwise, execute the child query plan of the CollectLimitExec object to obtain an RDD of
+ * data to collect.
+ * 3. Use an iterative approach to collect data in batches until the specified limit is reached.
+ * In each iteration, it selects a subset of the partitions of the RDD to scan and tries to
+ * collect data from them.
+ * 4. For each partition subset, we use the runJob method of the Spark context to execute a
+ * closure that scans the partition data and converts it to Arrow batches.
+ * 5. Check if the collected data reaches the specified limit. If not, it selects another subset
+ * of partitions to scan and repeats the process until the limit is reached or all partitions
+ * have been scanned.
+ * 6. Return an array of all the collected Arrow batches.
+ *
+ * Note that:
+ * 1. The returned Arrow batches may contain more than `limit` rows in total when the input
+ * plan produces more than `limit` rows
+ * 2. We don't implement the `takeFromEnd` logic
+ *
+ * @return the collected Arrow batches, each paired with its row count
+ */
+ def takeAsArrowBatches(
+ collectLimitExec: CollectLimitExec,
+ maxRecordsPerBatch: Long,
+ maxEstimatedBatchSize: Long,
+ timeZoneId: String): Array[Batch] = {
+ val n = collectLimitExec.limit
+ val schema = collectLimitExec.schema
+ if (n == 0) {
+ return new Array[Batch](0)
+ } else {
+ val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2)
+ // TODO: refactor and reuse the code from RDD's take()
+ val childRDD = collectLimitExec.child.execute()
+ val buf = new ArrayBuffer[Batch]
+ var bufferedRowSize = 0L
+ val totalParts = childRDD.partitions.length
+ var partsScanned = 0
+ while (bufferedRowSize < n && partsScanned < totalParts) {
+ // The number of partitions to try in this iteration. It is ok for this number to be
+ // greater than totalParts because we actually cap it at totalParts in runJob.
+ var numPartsToTry = limitInitialNumPartitions
+ if (partsScanned > 0) {
+ // If we didn't find any rows after the previous iteration, multiply by
+ // limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need
+ // to try, but overestimate it by 50%. We also cap the estimation in the end.
+ if (buf.isEmpty) {
+ numPartsToTry = partsScanned * limitScaleUpFactor
+ } else {
+ val left = n - bufferedRowSize
+ // As left > 0, numPartsToTry is always >= 1
+ numPartsToTry = Math.ceil(1.5 * left * partsScanned / bufferedRowSize).toInt
+ numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor)
+ }
+ }
+
+ val partsToScan =
+ partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
+
+ // TODO: SparkPlan.session was introduced in SPARK-35798; replace SparkSession.active with
+ // SparkPlan.session once we drop Spark-3.1.x support.
+ val sc = SparkSession.active.sparkContext
+ val res = sc.runJob(
+ childRDD,
+ (it: Iterator[InternalRow]) => {
+ val batches = toBatchIterator(
+ it,
+ schema,
+ maxRecordsPerBatch,
+ maxEstimatedBatchSize,
+ n,
+ timeZoneId)
+ batches.map(b => b -> batches.rowCountInLastBatch).toArray
+ },
+ partsToScan)
+
+ var i = 0
+ while (bufferedRowSize < n && i < res.length) {
+ var j = 0
+ val batches = res(i)
+ while (j < batches.length && n > bufferedRowSize) {
+ val batch = batches(j)
+ val (_, batchSize) = batch
+ buf += batch
+ bufferedRowSize += batchSize
+ j += 1
+ }
+ i += 1
+ }
+ partsScanned += partsToScan.size
+ }
+
+ buf.toArray
+ }
+ }
+
+ /**
+ * Spark introduced the config `spark.sql.limit.initialNumPartitions` in 3.4.0; see SPARK-40211.
+ */
+ private def limitInitialNumPartitions: Int = {
+ conf.getConfString("spark.sql.limit.initialNumPartitions", "1")
+ .toInt
+ }
+
+ /**
+ * Different from [[org.apache.spark.sql.execution.arrow.ArrowConverters.toBatchIterator]],
+ * the returned iterator tracks the row count of each emitted Arrow batch via `rowCountInLastBatch`.
+ */
+ private def toBatchIterator(
+ rowIter: Iterator[InternalRow],
+ schema: StructType,
+ maxRecordsPerBatch: Long,
+ maxEstimatedBatchSize: Long,
+ limit: Long,
+ timeZoneId: String): ArrowBatchIterator = {
+ new ArrowBatchIterator(
+ rowIter,
+ schema,
+ maxRecordsPerBatch,
+ maxEstimatedBatchSize,
+ limit,
+ timeZoneId,
+ TaskContext.get)
+ }
+
+ /**
+ * This class ArrowBatchIterator is derived from
+ * [[org.apache.spark.sql.execution.arrow.ArrowConverters.ArrowBatchWithSchemaIterator]],
+ * with two key differences:
+ * 1. there is no requirement to write the schema at the batch header
+ * 2. iteration halts when `rowCount` equals `limit`
+ */
+ private[sql] class ArrowBatchIterator(
+ rowIter: Iterator[InternalRow],
+ schema: StructType,
+ maxRecordsPerBatch: Long,
+ maxEstimatedBatchSize: Long,
+ limit: Long,
+ timeZoneId: String,
+ context: TaskContext)
+ extends Iterator[Array[Byte]] {
+
+ protected val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+ private val allocator =
+ ArrowUtils.rootAllocator.newChildAllocator(
+ s"to${this.getClass.getSimpleName}",
+ 0,
+ Long.MaxValue)
+
+ private val root = VectorSchemaRoot.create(arrowSchema, allocator)
+ protected val unloader = new VectorUnloader(root)
+ protected val arrowWriter = ArrowWriter.create(root)
+
+ Option(context).foreach {
+ _.addTaskCompletionListener[Unit] { _ =>
+ root.close()
+ allocator.close()
+ }
+ }
+
+ override def hasNext: Boolean = (rowIter.hasNext && rowCount < limit) || {
+ root.close()
+ allocator.close()
+ false
+ }
+
+ var rowCountInLastBatch: Long = 0
+ var rowCount: Long = 0
+
+ override def next(): Array[Byte] = {
+ val out = new ByteArrayOutputStream()
+ val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+ rowCountInLastBatch = 0
+ var estimatedBatchSize = 0L
+ Utils.tryWithSafeFinally {
+
+ // Always write the first row.
+ while (rowIter.hasNext && (
+ // For maxEstimatedBatchSize and maxRecordsPerBatch, respect whichever is smaller.
+ // If the size in bytes is positive (set properly), always write the first row.
+ rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0 ||
+ // If the max size in bytes is 0 or negative, do not limit by batch size.
+ estimatedBatchSize <= 0 ||
+ estimatedBatchSize < maxEstimatedBatchSize ||
+ // If the max record count is 0 or negative, do not limit by record count.
+ maxRecordsPerBatch <= 0 ||
+ rowCountInLastBatch < maxRecordsPerBatch) &&
+ rowCount < limit) {
+ val row = rowIter.next()
+ arrowWriter.write(row)
+ estimatedBatchSize += (row match {
+ case ur: UnsafeRow => ur.getSizeInBytes
+ // Trying to estimate the size of the current row
+ case _: InternalRow => schema.defaultSize
+ })
+ rowCountInLastBatch += 1
+ rowCount += 1
+ }
+ arrowWriter.finish()
+ val batch = unloader.getRecordBatch()
+ MessageSerializer.serialize(writeChannel, batch)
+
+ // Always write the Ipc options at the end.
+ ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)
+
+ batch.close()
+ } {
+ arrowWriter.reset()
+ }
+
+ out.toByteArray
+ }
+ }
+
+ // for testing
+ def fromBatchIterator(
+ arrowBatchIter: Iterator[Array[Byte]],
+ schema: StructType,
+ timeZoneId: String,
+ context: TaskContext): Iterator[InternalRow] = {
+ ArrowConverters.fromBatchIterator(arrowBatchIter, schema, timeZoneId, context)
+ }
+}
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
index 1a5429373..1c8d32c48 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
@@ -17,18 +17,75 @@
package org.apache.spark.sql.kyuubi
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.{ByteUnit, JavaUtils}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
+import org.apache.spark.sql.execution.{CollectLimitExec, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.arrow.{ArrowConverters, KyuubiArrowConverters}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
+import org.apache.kyuubi.engine.spark.KyuubiSparkUtil
import org.apache.kyuubi.engine.spark.schema.RowSet
+import org.apache.kyuubi.reflection.DynMethods
+
+object SparkDatasetHelper extends Logging {
+
+ def executeCollect(df: DataFrame): Array[Array[Byte]] = withNewExecutionId(df) {
+ executeArrowBatchCollect(df.queryExecution.executedPlan)
+ }
+
+ def executeArrowBatchCollect: SparkPlan => Array[Array[Byte]] = {
+ case adaptiveSparkPlan: AdaptiveSparkPlanExec =>
+ executeArrowBatchCollect(finalPhysicalPlan(adaptiveSparkPlan))
+ // TODO: avoid extra shuffle if `offset` > 0
+ case collectLimit: CollectLimitExec if offset(collectLimit) > 0 =>
+ logWarning("unsupported offset > 0, an extra shuffle will be introduced.")
+ toArrowBatchRdd(collectLimit).collect()
+ case collectLimit: CollectLimitExec if collectLimit.limit >= 0 =>
+ doCollectLimit(collectLimit)
+ case collectLimit: CollectLimitExec if collectLimit.limit < 0 =>
+ executeArrowBatchCollect(collectLimit.child)
+ case plan: SparkPlan =>
+ toArrowBatchRdd(plan).collect()
+ }
-object SparkDatasetHelper {
def toArrowBatchRdd[T](ds: Dataset[T]): RDD[Array[Byte]] = {
ds.toArrowBatchRdd
}
+ /**
+ * Forked from [[Dataset.toArrowBatchRdd(plan: SparkPlan)]].
+ * Convert to an RDD of serialized ArrowRecordBatches.
+ */
+ def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
+ val schemaCaptured = plan.schema
+ // TODO: SparkPlan.session was introduced in SPARK-35798; replace SparkSession.active with
+ // SparkPlan.session once we drop Spark-3.1.x support.
+ val maxRecordsPerBatch = SparkSession.active.sessionState.conf.arrowMaxRecordsPerBatch
+ val timeZoneId = SparkSession.active.sessionState.conf.sessionLocalTimeZone
+ plan.execute().mapPartitionsInternal { iter =>
+ val context = TaskContext.get()
+ ArrowConverters.toBatchIterator(
+ iter,
+ schemaCaptured,
+ maxRecordsPerBatch,
+ timeZoneId,
+ context)
+ }
+ }
+
+ def toArrowBatchLocalIterator(df: DataFrame): Iterator[Array[Byte]] = {
+ withNewExecutionId(df) {
+ toArrowBatchRdd(df).toLocalIterator
+ }
+ }
+
def convertTopLevelComplexTypeToHiveString(
df: DataFrame,
timestampAsString: Boolean): DataFrame = {
@@ -68,11 +125,108 @@ object SparkDatasetHelper {
* Fork from Apache Spark-3.3.1 org.apache.spark.sql.catalyst.util.quoteIfNeeded to adapt to
* Spark-3.1.x
*/
- def quoteIfNeeded(part: String): String = {
+ private def quoteIfNeeded(part: String): String = {
if (part.matches("[a-zA-Z0-9_]+") && !part.matches("\\d+")) {
part
} else {
s"`${part.replace("`", "``")}`"
}
}
+
+ private lazy val maxBatchSize: Long = {
+ // respect spark connect config
+ KyuubiSparkUtil.globalSparkContext
+ .getConf
+ .getOption("spark.connect.grpc.arrow.maxBatchSize")
+ .orElse(Option("4m"))
+ .map(JavaUtils.byteStringAs(_, ByteUnit.MiB))
+ .get
+ }
+
+ private def doCollectLimit(collectLimit: CollectLimitExec): Array[Array[Byte]] = {
+ // TODO: SparkPlan.session was introduced in SPARK-35798; replace SparkSession.active with
+ // SparkPlan.session once we drop Spark-3.1.x support.
+ val timeZoneId = SparkSession.active.sessionState.conf.sessionLocalTimeZone
+ val maxRecordsPerBatch = SparkSession.active.sessionState.conf.arrowMaxRecordsPerBatch
+
+ val batches = KyuubiArrowConverters.takeAsArrowBatches(
+ collectLimit,
+ maxRecordsPerBatch,
+ maxBatchSize,
+ timeZoneId)
+
+ // Note that the total row count of the returned Arrow batches may exceed `limit`, so slice
+ // the result down to exactly `limit` rows.
+ val result = ArrayBuffer[Array[Byte]]()
+ var i = 0
+ var rest = collectLimit.limit
+ while (i < batches.length && rest > 0) {
+ val (batch, size) = batches(i)
+ if (size <= rest) {
+ result += batch
+ // this batch's row count is within the remaining limit, so the Long-to-Int conversion is safe
+ rest -= size.toInt
+ } else { // size > rest
+ result += KyuubiArrowConverters.slice(collectLimit.schema, timeZoneId, batch, 0, rest)
+ rest = 0
+ }
+ i += 1
+ }
+ result.toArray
+ }
+
+ /**
+ * This method provides a reflection-based implementation of
+ * [[AdaptiveSparkPlanExec.finalPhysicalPlan]] that enables us to adapt to the Spark runtime
+ * without patching SPARK-41914.
+ *
+ * TODO: Once we drop support for Spark 3.1.x, we can directly call
+ * [[AdaptiveSparkPlanExec.finalPhysicalPlan]].
+ */
+ def finalPhysicalPlan(adaptiveSparkPlanExec: AdaptiveSparkPlanExec): SparkPlan = {
+ withFinalPlanUpdate(adaptiveSparkPlanExec, identity)
+ }
+
+ /**
+ * A reflection-based implementation of [[AdaptiveSparkPlanExec.withFinalPlanUpdate]].
+ */
+ private def withFinalPlanUpdate[T](
+ adaptiveSparkPlanExec: AdaptiveSparkPlanExec,
+ fun: SparkPlan => T): T = {
+ val getFinalPhysicalPlan = DynMethods.builder("getFinalPhysicalPlan")
+ .hiddenImpl(adaptiveSparkPlanExec.getClass)
+ .build()
+ val plan = getFinalPhysicalPlan.invoke[SparkPlan](adaptiveSparkPlanExec)
+ val result = fun(plan)
+ val finalPlanUpdate = DynMethods.builder("finalPlanUpdate")
+ .hiddenImpl(adaptiveSparkPlanExec.getClass)
+ .build()
+ finalPlanUpdate.invoke[Unit](adaptiveSparkPlanExec)
+ result
+ }
+
+ /**
+ * Offset support was added in Spark 3.4.0 (see SPARK-28330). To ensure backward compatibility
+ * with earlier versions of Spark, this function uses reflection to call `offset`.
+ */
+ private def offset(collectLimitExec: CollectLimitExec): Int = {
+ Option(
+ DynMethods.builder("offset")
+ .impl(collectLimitExec.getClass)
+ .orNoop()
+ .build()
+ .invoke[Int](collectLimitExec))
+ .getOrElse(0)
+ }
+
+ /**
+ * Refer to org.apache.spark.sql.Dataset#withAction(): assign a new execution id to the
+ * arrow-based operation so that we can track arrow-based queries on the Spark UI.
+ */
+ private def withNewExecutionId[T](df: DataFrame)(body: => T): T = {
+ SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) {
+ df.queryExecution.executedPlan.resetMetrics()
+ body
+ }
+ }
}
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
index ae6237bb5..2ef29b398 100644
--- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
@@ -18,16 +18,28 @@
package org.apache.kyuubi.engine.spark.operation
import java.sql.Statement
+import java.util.{Set => JSet}
import org.apache.spark.KyuubiSparkContextHelper
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
+import org.apache.spark.sql.{QueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{CollectLimitExec, QueryExecution, SparkPlan}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
+import org.apache.spark.sql.execution.exchange.Exchange
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.kyuubi.SparkDatasetHelper
import org.apache.spark.sql.util.QueryExecutionListener
+import org.apache.kyuubi.KyuubiException
import org.apache.kyuubi.config.KyuubiConf
import org.apache.kyuubi.engine.spark.{SparkSQLEngine, WithSparkSQLEngine}
import org.apache.kyuubi.engine.spark.session.SparkSessionImpl
import org.apache.kyuubi.operation.SparkDataTypeTests
+import org.apache.kyuubi.reflection.DynFields
class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTypeTests {
@@ -138,6 +150,155 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
assert(metrics("numOutputRows").value === 1)
}
+ test("SparkDatasetHelper.executeArrowBatchCollect should return expect row count") {
+ val returnSize = Seq(
+ 0, // the Spark optimizer guarantees `limit != 0`; this is just a sanity check
+ 7, // less than one partition
+ 10, // equal to one partition
+ 13, // between one and two partitions, run two jobs
+ 20, // equal to two partitions
+ 29, // between two and three partitions
+ 1000, // all partitions
+ 1001) // more than total row count
+
+ def runAndCheck(sparkPlan: SparkPlan, expectSize: Int): Unit = {
+ val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(sparkPlan)
+ val rows = KyuubiArrowConverters.fromBatchIterator(
+ arrowBinary.iterator,
+ sparkPlan.schema,
+ "",
+ KyuubiSparkContextHelper.dummyTaskContext())
+ assert(rows.size == expectSize)
+ }
+
+ val excludedRules = Seq(
+ "org.apache.spark.sql.catalyst.optimizer.EliminateLimits",
+ "org.apache.spark.sql.catalyst.optimizer.OptimizeLimitZero",
+ "org.apache.spark.sql.execution.adaptive.AQEPropagateEmptyRelation").mkString(",")
+ withSQLConf(
+ SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> excludedRules,
+ SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> excludedRules) {
+ // aqe
+ // outermost AdaptiveSparkPlanExec
+ spark.range(1000)
+ .repartitionByRange(100, col("id"))
+ .createOrReplaceTempView("t_1")
+ spark.sql("select * from t_1")
+ .foreachPartition { p: Iterator[Row] =>
+ assert(p.length == 10)
+ ()
+ }
+ returnSize.foreach { size =>
+ val df = spark.sql(s"select * from t_1 limit $size")
+ val headPlan = df.queryExecution.executedPlan.collectLeaves().head
+ if (SPARK_ENGINE_RUNTIME_VERSION >= "3.2") {
+ assert(headPlan.isInstanceOf[AdaptiveSparkPlanExec])
+ val finalPhysicalPlan =
+ SparkDatasetHelper.finalPhysicalPlan(headPlan.asInstanceOf[AdaptiveSparkPlanExec])
+ assert(finalPhysicalPlan.isInstanceOf[CollectLimitExec])
+ }
+ if (size > 1000) {
+ runAndCheck(df.queryExecution.executedPlan, 1000)
+ } else {
+ runAndCheck(df.queryExecution.executedPlan, size)
+ }
+ }
+
+ // outermost CollectLimitExec
+ spark.range(0, 1000, 1, numPartitions = 100)
+ .createOrReplaceTempView("t_2")
+ spark.sql("select * from t_2")
+ .foreachPartition { p: Iterator[Row] =>
+ assert(p.length == 10)
+ ()
+ }
+ returnSize.foreach { size =>
+ val df = spark.sql(s"select * from t_2 limit $size")
+ val plan = df.queryExecution.executedPlan
+ assert(plan.isInstanceOf[CollectLimitExec])
+ if (size > 1000) {
+ runAndCheck(df.queryExecution.executedPlan, 1000)
+ } else {
+ runAndCheck(df.queryExecution.executedPlan, size)
+ }
+ }
+ }
+ }
+
+ test("aqe should work properly") {
+
+ val s = spark
+ import s.implicits._
+
+ spark.sparkContext.parallelize(
+ (1 to 100).map(i => TestData(i, i.toString))).toDF()
+ .createOrReplaceTempView("testData")
+ spark.sparkContext.parallelize(
+ TestData2(1, 1) ::
+ TestData2(1, 2) ::
+ TestData2(2, 1) ::
+ TestData2(2, 2) ::
+ TestData2(3, 1) ::
+ TestData2(3, 2) :: Nil,
+ 2).toDF()
+ .createOrReplaceTempView("testData2")
+
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+ val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT * FROM(
+ | SELECT * FROM testData join testData2 ON key = a where value = '1'
+ |) LIMIT 1
+ |""".stripMargin)
+ val smj = plan.collect { case smj: SortMergeJoinExec => smj }
+ val bhj = adaptivePlan.collect { case bhj: BroadcastHashJoinExec => bhj }
+ assert(smj.size == 1)
+ assert(bhj.size == 1)
+ }
+ }
+
+ test("result offset support") {
+ assume(SPARK_ENGINE_RUNTIME_VERSION > "3.3")
+ var numStages = 0
+ val listener = new SparkListener {
+ override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+ numStages = jobStart.stageInfos.length
+ }
+ }
+ withJdbcStatement() { statement =>
+ withSparkListener(listener) {
+ withPartitionedTable("t_3") {
+ statement.executeQuery("select * from t_3 limit 10 offset 10")
+ }
+ KyuubiSparkContextHelper.waitListenerBus(spark)
+ }
+ }
+ // an extra shuffle is introduced when `offset` > 0
+ assert(numStages == 2)
+ }
+
+ test("arrow serialization should not introduce extra shuffle for outermost limit") {
+ var numStages = 0
+ val listener = new SparkListener {
+ override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+ numStages = jobStart.stageInfos.length
+ }
+ }
+ withJdbcStatement() { statement =>
+ withSparkListener(listener) {
+ withPartitionedTable("t_3") {
+ statement.executeQuery("select * from t_3 limit 1000")
+ }
+ KyuubiSparkContextHelper.waitListenerBus(spark)
+ }
+ }
+ // Should be only one stage since there is no shuffle.
+ assert(numStages == 1)
+ }
+
private def checkResultSetFormat(statement: Statement, expectFormat: String): Unit = {
val query =
s"""
@@ -177,4 +338,101 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
.allSessions()
.foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.unregister(listener))
}
+
+ private def withSparkListener[T](listener: SparkListener)(body: => T): T = {
+ withAllSessions(s => s.sparkContext.addSparkListener(listener))
+ try {
+ body
+ } finally {
+ withAllSessions(s => s.sparkContext.removeSparkListener(listener))
+ }
+
+ }
+
+ private def withPartitionedTable[T](viewName: String)(body: => T): T = {
+ withAllSessions { spark =>
+ spark.range(0, 1000, 1, numPartitions = 100)
+ .createOrReplaceTempView(viewName)
+ }
+ try {
+ body
+ } finally {
+ withAllSessions { spark =>
+ spark.sql(s"DROP VIEW IF EXISTS $viewName")
+ }
+ }
+ }
+
+ private def withAllSessions(op: SparkSession => Unit): Unit = {
+ SparkSQLEngine.currentEngine.get
+ .backendService
+ .sessionManager
+ .allSessions()
+ .map(_.asInstanceOf[SparkSessionImpl].spark)
+ .foreach(op(_))
+ }
+
+ private def runAdaptiveAndVerifyResult(query: String): (SparkPlan, SparkPlan) = {
+ val dfAdaptive = spark.sql(query)
+ val planBefore = dfAdaptive.queryExecution.executedPlan
+ val result = dfAdaptive.collect()
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+ val df = spark.sql(query)
+ QueryTest.checkAnswer(df, df.collect().toSeq)
+ }
+ val planAfter = dfAdaptive.queryExecution.executedPlan
+ val adaptivePlan = planAfter.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+ val exchanges = adaptivePlan.collect {
+ case e: Exchange => e
+ }
+ assert(exchanges.isEmpty, "The final plan should not contain any Exchange node.")
+ (dfAdaptive.queryExecution.sparkPlan, adaptivePlan)
+ }
+
+ /**
+ * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL
+ * configurations.
+ */
+ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
+ val conf = SQLConf.get
+ val (keys, values) = pairs.unzip
+ val currentValues = keys.map { key =>
+ if (conf.contains(key)) {
+ Some(conf.getConfString(key))
+ } else {
+ None
+ }
+ }
+ (keys, values).zipped.foreach { (k, v) =>
+ if (isStaticConfigKey(k)) {
+ throw new KyuubiException(s"Cannot modify the value of a static config: $k")
+ }
+ conf.setConfString(k, v)
+ }
+ try f
+ finally {
+ keys.zip(currentValues).foreach {
+ case (key, Some(value)) => conf.setConfString(key, value)
+ case (key, None) => conf.unsetConf(key)
+ }
+ }
+ }
+
+ /**
+ * This method provides a reflection-based implementation of [[SQLConf.isStaticConfigKey]] to
+ * adapt to Spark-3.1.x.
+ *
+ * TODO: Once we drop support for Spark 3.1.x, we can directly call
+ * [[SQLConf.isStaticConfigKey()]].
+ */
+ private def isStaticConfigKey(key: String): Boolean = {
+ val staticConfKeys = DynFields.builder()
+ .hiddenImpl(SQLConf.getClass, "staticConfKeys")
+ .build[JSet[String]](SQLConf)
+ .get()
+ staticConfKeys.contains(key)
+ }
}
+
+case class TestData(key: Int, value: String)
+case class TestData2(a: Int, b: Int)
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala
index 8293123ea..1b662eadf 100644
--- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala
@@ -27,4 +27,6 @@ object KyuubiSparkContextHelper {
def waitListenerBus(spark: SparkSession): Unit = {
spark.sparkContext.listenerBus.waitUntilEmpty()
}
+
+ def dummyTaskContext(): TaskContextImpl = TaskContext.empty()
}
diff --git a/pom.xml b/pom.xml
index b8c217a2a..aa6b7818f 100644
--- a/pom.xml
+++ b/pom.xml
@@ -529,8 +529,8 @@
<artifactId>hadoop-client</artifactId>
</exclusion>
<!--
- The module is only used in Kyuubi Spark Extensions, so we don't care about which
- version of Log4j it depends on.
+ The module is only used in Kyuubi Spark Extensions and the Kyuubi Spark SQL Engine, so we
+ don't care about which version of Log4j it depends on.
-->
</exclusions>
</dependency>