You are viewing a plain-text version of this content. (The link to the canonical version is not preserved in this copy.)
Posted to commits@spark.apache.org by gu...@apache.org on 2022/11/18 00:44:02 UTC

[spark] branch master updated: [SPARK-41165][CONNECT] Avoid hangs in the arrow collect code path

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 637cf4f4ab8 [SPARK-41165][CONNECT] Avoid hangs in the arrow collect code path
637cf4f4ab8 is described below

commit 637cf4f4ab84708e58f7265b8dea928e1964a95f
Author: Herman van Hovell <he...@databricks.com>
AuthorDate: Fri Nov 18 09:43:48 2022 +0900

    [SPARK-41165][CONNECT] Avoid hangs in the arrow collect code path
    
    ### What changes were proposed in this pull request?
    Two changes:
    1. Make sure Connect's Arrow result path handles errors properly and does not hang.
    2. Fix a common source of non-serializable exceptions in `SparkConnectStreamHandler`.
    
    ### Why are the changes needed?
    The current Arrow result code path for Connect assumes that no error can happen during execution. The main thread waits for each partition's result in turn, so when the job fails and that result never arrives, the main thread hangs.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
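    Added a unit test, `SPARK-41165: failures in the arrow collect path should not cause hangs`, to `SparkConnectServiceSuite`.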
    
    Closes #38681 from hvanhovell/SPARK-41165.
    
    Authored-by: Herman van Hovell <he...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../service/SparkConnectStreamHandler.scala        | 55 ++++++++++++++++++----
 .../connect/planner/SparkConnectServiceSuite.scala | 48 +++++++++++++++++++
 2 files changed, 94 insertions(+), 9 deletions(-)

diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index ec2db3efa96..a780858d55c 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connect.service
 
 import scala.collection.JavaConverters._
+import scala.util.control.NonFatal
 
 import com.google.protobuf.ByteString
 import io.grpc.stub.StreamObserver
@@ -27,13 +28,15 @@ import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{Request, Response}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.ThreadUtils
 
 class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging {
-
   // The maximum batch size in bytes for a single batch of data to be returned via proto.
   private val MAX_BATCH_SIZE: Long = 4 * 1024 * 1024
 
@@ -139,14 +142,13 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       if (numPartitions > 0) {
         type Batch = (Array[Byte], Long)
 
-        val batches = rows.mapPartitionsInternal { iter =>
-          val newIter = ArrowConverters
-            .toBatchWithSchemaIterator(iter, schema, maxRecordsPerBatch, maxBatchSize, timeZoneId)
-          newIter.map { batch: Array[Byte] => (batch, newIter.rowCountInLastBatch) }
-        }
+        val batches = rows.mapPartitionsInternal(
+          SparkConnectStreamHandler
+            .rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize, timeZoneId))
 
         val signal = new Object
         val partitions = collection.mutable.Map.empty[Int, Array[Batch]]
+        var error: Throwable = null
 
         val processPartition = (iter: Iterator[Batch]) => iter.toArray
 
@@ -161,13 +163,23 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
           ()
         }
 
-        spark.sparkContext.submitJob(
+        val future = spark.sparkContext.submitJob(
           rdd = batches,
           processPartition = processPartition,
           partitions = Seq.range(0, numPartitions),
           resultHandler = resultHandler,
           resultFunc = () => ())
 
+        // Collect errors and propagate them to the main thread.
+        future.onComplete { result =>
+          result.failed.foreach { throwable =>
+            signal.synchronized {
+              error = throwable
+              signal.notify()
+            }
+          }
+        }(ThreadUtils.sameThread)
+
         // The main thread will wait until 0-th partition is available,
         // then send it to client and wait for the next partition.
         // Different from the implementation of [[Dataset#collectAsArrowToPython]], it sends
@@ -178,11 +190,18 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
         while (currentPartitionId < numPartitions) {
           val partition = signal.synchronized {
             var result = partitions.remove(currentPartitionId)
-            while (result.isEmpty) {
+            while (result.isEmpty && error == null) {
               signal.wait()
               result = partitions.remove(currentPartitionId)
             }
-            result.get
+            error match {
+              case NonFatal(e) =>
+                responseObserver.onError(error)
+                logError("Error while processing query.", e)
+                return
+              case fatal: Throwable => throw fatal
+              case null => result.get
+            }
           }
 
           partition.foreach { case (bytes, count) =>
@@ -236,6 +255,24 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
   }
 }
 
+object SparkConnectStreamHandler {
+  type Batch = (Array[Byte], Long)
+
+  private[service] def rowToArrowConverter(
+      schema: StructType,
+      maxRecordsPerBatch: Int,
+      maxBatchSize: Long,
+      timeZoneId: String): Iterator[InternalRow] => Iterator[Batch] = { rows =>
+    val batches = ArrowConverters.toBatchWithSchemaIterator(
+      rows,
+      schema,
+      maxRecordsPerBatch,
+      maxBatchSize,
+      timeZoneId)
+    batches.map(b => b -> batches.rowCountInLastBatch)
+  }
+}
+
 object MetricGenerator extends AdaptiveSparkPlanHelper {
   def buildMetrics(p: SparkPlan): Response.Metrics = {
     val b = Response.Metrics.newBuilder
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 4be8d1705b9..7ff3a823fa1 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -16,9 +16,18 @@
  */
 package org.apache.spark.sql.connect.planner
 
+import scala.concurrent.Promise
+import scala.concurrent.duration._
+
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.SparkException
 import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.dsl.MockRemoteSession
+import org.apache.spark.sql.connect.dsl.plans._
 import org.apache.spark.sql.connect.service.SparkConnectService
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.util.ThreadUtils
 
 /**
  * Testing Connect Service implementation.
@@ -55,4 +64,43 @@ class SparkConnectServiceSuite extends SharedSparkSession {
           && schema.getFields(1).getType.getKindCase == proto.DataType.KindCase.STRING)
     }
   }
+
+  test("SPARK-41165: failures in the arrow collect path should not cause hangs") {
+    val instance = new SparkConnectService(false)
+
+    // Add an always crashing UDF
+    val session = SparkConnectService.getOrCreateIsolatedSession("c1").session
+    val instaKill: Long => Long = { _ =>
+      throw new Exception("Kaboom")
+    }
+    session.udf.register("insta_kill", instaKill)
+
+    val connect = new MockRemoteSession()
+    val context = proto.Request.UserContext
+      .newBuilder()
+      .setUserId("c1")
+      .build()
+    val plan = proto.Plan
+      .newBuilder()
+      .setRoot(connect.sql("select insta_kill(id) from range(10)"))
+      .build()
+    val request = proto.Request
+      .newBuilder()
+      .setPlan(plan)
+      .setUserContext(context)
+      .build()
+
+    val promise = Promise[Seq[proto.Response]]
+    instance.executePlan(
+      request,
+      new StreamObserver[proto.Response] {
+        private val responses = Seq.newBuilder[proto.Response]
+        override def onNext(v: proto.Response): Unit = responses += v
+        override def onError(throwable: Throwable): Unit = promise.failure(throwable)
+        override def onCompleted(): Unit = promise.success(responses.result())
+      })
+    intercept[SparkException] {
+      ThreadUtils.awaitResult(promise.future, 2.seconds)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org