Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2022/11/22 12:00:34 UTC

[GitHub] [spark] HyukjinKwon opened a new pull request, #38759: [SPARK-41224][SPARK-41165][SPARK-41184] Optimized Arrow-based collect implementation to stream from server to client

HyukjinKwon opened a new pull request, #38759:
URL: https://github.com/apache/spark/pull/38759

   ### What changes were proposed in this pull request?
   
   This PR proposes an optimized Arrow-based collect. It is virtually the same as https://github.com/apache/spark/pull/38720, which implements the logic, except for a couple of nits.
   
   ### Why are the changes needed?
   
   To stream Arrow batches from the server to the client side instead of waiting for all the jobs to finish.
   
   ### Does this PR introduce _any_ user-facing change?
   
   No, this feature isn't released yet.
   
   ### How was this patch tested?
   
   A unit test was added.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] hvanhovell closed pull request #38759: [SPARK-41224][SPARK-41165][SPARK-41184] Optimized Arrow-based collect implementation to stream from server to client

Posted by GitBox <gi...@apache.org>.
hvanhovell closed pull request #38759: [SPARK-41224][SPARK-41165][SPARK-41184] Optimized Arrow-based collect implementation to stream from server to client
URL: https://github.com/apache/spark/pull/38759




[GitHub] [spark] HyukjinKwon commented on a diff in pull request #38759: [SPARK-41224][SPARK-41165][SPARK-41184] Optimized Arrow-based collect implementation to stream from server to client

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #38759:
URL: https://github.com/apache/spark/pull/38759#discussion_r1030973655


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -71,20 +73,80 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
       var numSent = 0
 
       if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long)
+
         val batches = rows.mapPartitionsInternal(
           SparkConnectStreamHandler
             .rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize, timeZoneId))
 
-        batches.collect().foreach { case (bytes, count) =>
-          val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
-          val batch = proto.ExecutePlanResponse.ArrowBatch
-            .newBuilder()
-            .setRowCount(count)
-            .setData(ByteString.copyFrom(bytes))
-            .build()
-          response.setArrowBatch(batch)
-          responseObserver.onNext(response.build())
-          numSent += 1
+        val signal = new Object
+        val partitions = new Array[Array[Batch]](numPartitions)
+        var error: Option[Throwable] = None
+
+        // This callback is executed by the DAGScheduler thread.
+        // After fetching a partition, it inserts the partition into the Map, and then
+        // wakes up the main thread.
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          ()
+        }
+
+        val future = spark.sparkContext.submitJob(
+          rdd = batches,
+          processPartition = (iter: Iterator[Batch]) => iter.toArray,
+          partitions = Seq.range(0, numPartitions),
+          resultHandler = resultHandler,
+          resultFunc = () => ())
+
+        // Collect errors and propagate them to the main thread.
+        future.onComplete { result =>
+          result.failed.foreach { throwable =>
+            signal.synchronized {
+              error = Some(throwable)
+              signal.notify()
+            }
+          }
+        }(ThreadUtils.sameThread)
+
+        // The main thread will wait until 0-th partition is available,
+        // then send it to client and wait for the next partition.
+        // Different from the implementation of [[Dataset#collectAsArrowToPython]], it sends
+        // the arrow batches in main thread to avoid DAGScheduler thread been blocked for
+        // tasks not related to scheduling. This is particularly important if there are
+        // multiple users or clients running code at the same time.
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          signal.synchronized {
+            while (partitions(currentPartitionId) == null && error.isEmpty) {
+              signal.wait()
+            }
+
+            error.foreach {
+              case NonFatal(e) =>
+                responseObserver.onError(e)
+                logError("Error while processing query.", e)
+                return
+              case other => throw other
+            }
+          }
+
+          partitions(currentPartitionId).foreach { case (bytes, count) =>

Review Comment:
   I actually think this is fine, but let me address it to be doubly sure.





[GitHub] [spark] hvanhovell commented on a diff in pull request #38759: [SPARK-41224][SPARK-41165][SPARK-41184] Optimized Arrow-based collect implementation to stream from server to client

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38759:
URL: https://github.com/apache/spark/pull/38759#discussion_r1029301301


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -71,20 +73,82 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
       var numSent = 0
 
       if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long)
+
         val batches = rows.mapPartitionsInternal(
           SparkConnectStreamHandler
             .rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize, timeZoneId))
 
-        batches.collect().foreach { case (bytes, count) =>
-          val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
-          val batch = proto.ExecutePlanResponse.ArrowBatch
-            .newBuilder()
-            .setRowCount(count)
-            .setData(ByteString.copyFrom(bytes))
-            .build()
-          response.setArrowBatch(batch)
-          responseObserver.onNext(response.build())
-          numSent += 1
+        val signal = new Object
+        val partitions = new Array[Array[Batch]](numPartitions - 1)
+        var error: Option[Throwable] = None
+
+        // This callback is executed by the DAGScheduler thread.
+        // After fetching a partition, it inserts the partition into the Map, and then
+        // wakes up the main thread.
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          ()
+        }
+
+        val future = spark.sparkContext.submitJob(
+          rdd = batches,
+          processPartition = (iter: Iterator[Batch]) => iter.toArray,
+          partitions = Seq.range(0, numPartitions),
+          resultHandler = resultHandler,
+          resultFunc = () => ())
+
+        // Collect errors and propagate them to the main thread.
+        future.onComplete { result =>
+          result.failed.foreach { throwable =>
+            signal.synchronized {
+              error = Some(throwable)
+              signal.notify()
+            }
+          }
+        }(ThreadUtils.sameThread)
+
+        // The main thread will wait until 0-th partition is available,
+        // then send it to client and wait for the next partition.
+        // Different from the implementation of [[Dataset#collectAsArrowToPython]], it sends
+        // the arrow batches in main thread to avoid DAGScheduler thread been blocked for
+        // tasks not related to scheduling. This is particularly important if there are
+        // multiple users or clients running code at the same time.
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            val result = partitions(currentPartitionId)

Review Comment:
   You will need to update the result. Otherwise this will keep looping.
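
A minimal sketch of the point in this comment, not the PR's code: when the slot is copied into a local value before the wait loop, the local has to be re-read after every wake-up, or the loop condition never changes. `WaitLoopSketch`, `Slots`, `put`, and `await` are hypothetical names used only for illustration, and strings stand in for Arrow batches.

```scala
object WaitLoopSketch {
  // Hypothetical holder for per-partition results, guarded by a single monitor.
  final class Slots(n: Int) {
    private val signal = new Object
    private val slots = new Array[String](n)

    // Producer side: publish a value and wake up any waiting thread.
    def put(i: Int, value: String): Unit = signal.synchronized {
      slots(i) = value
      signal.notifyAll()
    }

    // Consumer side: block until slot i has been filled.
    def await(i: Int): String = signal.synchronized {
      var result = slots(i)
      while (result == null) {
        signal.wait()
        // Re-read the slot after each wake-up. Without this assignment,
        // `result` stays null and the loop never exits.
        result = slots(i)
      }
      result
    }
  }

  def main(args: Array[String]): Unit = {
    val slots = new Slots(2)
    // Fill the slots out of order from another thread.
    val producer = new Thread(() => {
      Thread.sleep(50)
      slots.put(1, "b")
      slots.put(0, "a")
    })
    producer.start()
    println(slots.await(0))
    println(slots.await(1))
    producer.join()
  }
}
```

Testing the array element directly in the loop condition, as the later revision of the diff does, avoids the stale local altogether.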





[GitHub] [spark] hvanhovell commented on a diff in pull request #38759: [SPARK-41224][SPARK-41165][SPARK-41184] Optimized Arrow-based collect implementation to stream from server to client

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38759:
URL: https://github.com/apache/spark/pull/38759#discussion_r1030488519


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -71,20 +73,80 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
       var numSent = 0
 
       if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long)
+
         val batches = rows.mapPartitionsInternal(
           SparkConnectStreamHandler
             .rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize, timeZoneId))
 
-        batches.collect().foreach { case (bytes, count) =>
-          val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
-          val batch = proto.ExecutePlanResponse.ArrowBatch
-            .newBuilder()
-            .setRowCount(count)
-            .setData(ByteString.copyFrom(bytes))
-            .build()
-          response.setArrowBatch(batch)
-          responseObserver.onNext(response.build())
-          numSent += 1
+        val signal = new Object
+        val partitions = new Array[Array[Batch]](numPartitions)
+        var error: Option[Throwable] = None
+
+        // This callback is executed by the DAGScheduler thread.
+        // After fetching a partition, it inserts the partition into the Map, and then
+        // wakes up the main thread.
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          ()
+        }
+
+        val future = spark.sparkContext.submitJob(
+          rdd = batches,
+          processPartition = (iter: Iterator[Batch]) => iter.toArray,
+          partitions = Seq.range(0, numPartitions),
+          resultHandler = resultHandler,
+          resultFunc = () => ())
+
+        // Collect errors and propagate them to the main thread.
+        future.onComplete { result =>
+          result.failed.foreach { throwable =>
+            signal.synchronized {
+              error = Some(throwable)
+              signal.notify()
+            }
+          }
+        }(ThreadUtils.sameThread)
+
+        // The main thread will wait until 0-th partition is available,
+        // then send it to client and wait for the next partition.
+        // Different from the implementation of [[Dataset#collectAsArrowToPython]], it sends
+        // the arrow batches in main thread to avoid DAGScheduler thread been blocked for
+        // tasks not related to scheduling. This is particularly important if there are
+        // multiple users or clients running code at the same time.
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          signal.synchronized {
+            while (partitions(currentPartitionId) == null && error.isEmpty) {
+              signal.wait()
+            }
+
+            error.foreach {
+              case NonFatal(e) =>
+                responseObserver.onError(e)
+                logError("Error while processing query.", e)
+                return
+              case other => throw other
+            }
+          }
+
+          partitions(currentPartitionId).foreach { case (bytes, count) =>

Review Comment:
   I am not 100% sure if this is thread-safe. The array should only be accessed while holding the lock.
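
One way to address this concern, sketched under assumptions rather than taken from the PR: read the array element into a local value while still holding the lock, then send the batches outside the lock using only that local. `OrderedStreamSketch`, `streamInOrder`, `submit`, and `send` are hypothetical stand-ins; `submit` plays the role of the job submission with its result and failure callbacks, and `send` plays the role of `responseObserver.onNext`.

```scala
object OrderedStreamSketch {
  type Batch = (Array[Byte], Long)

  // Emits partitions in index order, even if they complete out of order.
  // Returns the number of batches sent, or the first recorded error.
  def streamInOrder(
      numPartitions: Int,
      submit: ((Int, Array[Batch]) => Unit, Throwable => Unit) => Unit,
      send: Batch => Unit): Either[Throwable, Int] = {
    val signal = new Object
    val partitions = new Array[Array[Batch]](numPartitions)
    var error: Option[Throwable] = None

    // Both callbacks write shared state only while holding the lock.
    submit(
      (id, part) => signal.synchronized { partitions(id) = part; signal.notifyAll() },
      t => signal.synchronized { error = Some(t); signal.notifyAll() })

    var sent = 0
    var current = 0
    while (current < numPartitions) {
      // Every read of `partitions` and `error` happens inside the lock;
      // only a local reference escapes it.
      val next: Either[Throwable, Array[Batch]] = signal.synchronized {
        while (partitions(current) == null && error.isEmpty) signal.wait()
        error.toLeft(partitions(current))
      }
      next match {
        case Left(t) => return Left(t)
        case Right(batches) =>
          // Sending happens outside the lock so producers are not blocked.
          batches.foreach { b => send(b); sent += 1 }
      }
      current += 1
    }
    Right(sent)
  }

  def main(args: Array[String]): Unit = {
    val data: Array[Array[Batch]] = Array(
      Array((Array[Byte](1), 1L)),
      Array((Array[Byte](2), 1L), (Array[Byte](3), 1L)))
    val order = scala.collection.mutable.ArrayBuffer.empty[Byte]
    val result = streamInOrder(
      2,
      (onPartition, _) => {
        // Complete partition 1 before partition 0 on another thread.
        val t = new Thread(() => (1 to 0 by -1).foreach(i => onPartition(i, data(i))))
        t.start()
      },
      batch => { order += batch._1.head; () })
    println(result)
    println(order.mkString(","))
  }
}
```

Whether the unlocked read in the diff is actually unsafe is the open question in this thread; the sketch only shows that keeping all shared reads under the lock costs one extra local value.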






