You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2022/11/18 00:44:02 UTC
[spark] branch master updated: [SPARK-41165][CONNECT] Avoid hangs in the arrow collect code path
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 637cf4f4ab8 [SPARK-41165][CONNECT] Avoid hangs in the arrow collect code path
637cf4f4ab8 is described below
commit 637cf4f4ab84708e58f7265b8dea928e1964a95f
Author: Herman van Hovell <he...@databricks.com>
AuthorDate: Fri Nov 18 09:43:48 2022 +0900
[SPARK-41165][CONNECT] Avoid hangs in the arrow collect code path
### What changes were proposed in this pull request?
Two changes:
1. Make sure Connect's Arrow result path handles execution errors properly instead of hanging.
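2. Fix a common source of non-serializable exceptions in `SparkConnectStreamHandler` by moving the row-to-Arrow conversion function out of the handler class and into its companion object (`rowToArrowConverter`).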
### Why are the changes needed?
The current Arrow result code path for Connect assumes that no error can happen during execution. As a result, when the job fails, the main thread keeps waiting for a partition result that never arrives, and the request hangs.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
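Added a unit test to `SparkConnectServiceSuite` that executes a query with an always-failing UDF and verifies that the error is propagated to the response observer instead of causing a hang.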
Closes #38681 from hvanhovell/SPARK-41165.
Authored-by: Herman van Hovell <he...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
.../service/SparkConnectStreamHandler.scala | 55 ++++++++++++++++++----
.../connect/planner/SparkConnectServiceSuite.scala | 48 +++++++++++++++++++
2 files changed, 94 insertions(+), 9 deletions(-)
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index ec2db3efa96..a780858d55c 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.connect.service
import scala.collection.JavaConverters._
+import scala.util.control.NonFatal
import com.google.protobuf.ByteString
import io.grpc.stub.StreamObserver
@@ -27,13 +28,15 @@ import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{Request, Response}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec}
import org.apache.spark.sql.execution.arrow.ArrowConverters
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.ThreadUtils
class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging {
-
// The maximum batch size in bytes for a single batch of data to be returned via proto.
private val MAX_BATCH_SIZE: Long = 4 * 1024 * 1024
@@ -139,14 +142,13 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
if (numPartitions > 0) {
type Batch = (Array[Byte], Long)
- val batches = rows.mapPartitionsInternal { iter =>
- val newIter = ArrowConverters
- .toBatchWithSchemaIterator(iter, schema, maxRecordsPerBatch, maxBatchSize, timeZoneId)
- newIter.map { batch: Array[Byte] => (batch, newIter.rowCountInLastBatch) }
- }
+ val batches = rows.mapPartitionsInternal(
+ SparkConnectStreamHandler
+ .rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize, timeZoneId))
val signal = new Object
val partitions = collection.mutable.Map.empty[Int, Array[Batch]]
+ var error: Throwable = null
val processPartition = (iter: Iterator[Batch]) => iter.toArray
@@ -161,13 +163,23 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
()
}
- spark.sparkContext.submitJob(
+ val future = spark.sparkContext.submitJob(
rdd = batches,
processPartition = processPartition,
partitions = Seq.range(0, numPartitions),
resultHandler = resultHandler,
resultFunc = () => ())
+ // Collect errors and propagate them to the main thread.
+ future.onComplete { result =>
+ result.failed.foreach { throwable =>
+ signal.synchronized {
+ error = throwable
+ signal.notify()
+ }
+ }
+ }(ThreadUtils.sameThread)
+
// The main thread will wait until 0-th partition is available,
// then send it to client and wait for the next partition.
// Different from the implementation of [[Dataset#collectAsArrowToPython]], it sends
@@ -178,11 +190,18 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
while (currentPartitionId < numPartitions) {
val partition = signal.synchronized {
var result = partitions.remove(currentPartitionId)
- while (result.isEmpty) {
+ while (result.isEmpty && error == null) {
signal.wait()
result = partitions.remove(currentPartitionId)
}
- result.get
+ error match {
+ case NonFatal(e) =>
+ responseObserver.onError(error)
+ logError("Error while processing query.", e)
+ return
+ case fatal: Throwable => throw fatal
+ case null => result.get
+ }
}
partition.foreach { case (bytes, count) =>
@@ -236,6 +255,24 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
}
}
+object SparkConnectStreamHandler {
+ type Batch = (Array[Byte], Long)
+
+ private[service] def rowToArrowConverter(
+ schema: StructType,
+ maxRecordsPerBatch: Int,
+ maxBatchSize: Long,
+ timeZoneId: String): Iterator[InternalRow] => Iterator[Batch] = { rows =>
+ val batches = ArrowConverters.toBatchWithSchemaIterator(
+ rows,
+ schema,
+ maxRecordsPerBatch,
+ maxBatchSize,
+ timeZoneId)
+ batches.map(b => b -> batches.rowCountInLastBatch)
+ }
+}
+
object MetricGenerator extends AdaptiveSparkPlanHelper {
def buildMetrics(p: SparkPlan): Response.Metrics = {
val b = Response.Metrics.newBuilder
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 4be8d1705b9..7ff3a823fa1 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -16,9 +16,18 @@
*/
package org.apache.spark.sql.connect.planner
+import scala.concurrent.Promise
+import scala.concurrent.duration._
+
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.SparkException
import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.dsl.MockRemoteSession
+import org.apache.spark.sql.connect.dsl.plans._
import org.apache.spark.sql.connect.service.SparkConnectService
import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.util.ThreadUtils
/**
* Testing Connect Service implementation.
@@ -55,4 +64,43 @@ class SparkConnectServiceSuite extends SharedSparkSession {
&& schema.getFields(1).getType.getKindCase == proto.DataType.KindCase.STRING)
}
}
+
+ test("SPARK-41165: failures in the arrow collect path should not cause hangs") {
+ val instance = new SparkConnectService(false)
+
+ // Add an always crashing UDF
+ val session = SparkConnectService.getOrCreateIsolatedSession("c1").session
+ val instaKill: Long => Long = { _ =>
+ throw new Exception("Kaboom")
+ }
+ session.udf.register("insta_kill", instaKill)
+
+ val connect = new MockRemoteSession()
+ val context = proto.Request.UserContext
+ .newBuilder()
+ .setUserId("c1")
+ .build()
+ val plan = proto.Plan
+ .newBuilder()
+ .setRoot(connect.sql("select insta_kill(id) from range(10)"))
+ .build()
+ val request = proto.Request
+ .newBuilder()
+ .setPlan(plan)
+ .setUserContext(context)
+ .build()
+
+ val promise = Promise[Seq[proto.Response]]
+ instance.executePlan(
+ request,
+ new StreamObserver[proto.Response] {
+ private val responses = Seq.newBuilder[proto.Response]
+ override def onNext(v: proto.Response): Unit = responses += v
+ override def onError(throwable: Throwable): Unit = promise.failure(throwable)
+ override def onCompleted(): Unit = promise.success(responses.result())
+ })
+ intercept[SparkException] {
+ ThreadUtils.awaitResult(promise.future, 2.seconds)
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org