Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2022/11/01 11:51:56 UTC

[GitHub] [spark] zhengruifeng opened a new pull request, #38468: [WIP][CONNECT][PYTHON] Arrow-based collect

zhengruifeng opened a new pull request, #38468:
URL: https://github.com/apache/spark/pull/38468

   <!--
   Thanks for sending a pull request!  Here are some tips for you:
     1. If this is your first time, please read our contributor guidelines: https://spark.apache.org/contributing.html
     2. Ensure you have added or run the appropriate tests for your PR: https://spark.apache.org/developer-tools.html
     3. If the PR is unfinished, add '[WIP]' in your PR title, e.g., '[WIP][SPARK-XXXX] Your PR title ...'.
     4. Be sure to keep the PR description updated to reflect all changes.
     5. Please write your PR title to summarize what this PR proposes.
     6. If possible, provide a concise example to reproduce the issue for a faster review.
     7. If you want to add a new configuration, please read the guideline first for naming configurations in
        'core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala'.
     8. If you want to add or modify an error type or message, please read the guideline first in
        'core/src/main/resources/error/README.md'.
   -->
   
   ### What changes were proposed in this pull request?
   <!--
   Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. 
   If possible, please consider writing useful notes for better and faster reviews in your PR. See the examples below.
     1. If you refactor some codes with changing classes, showing the class hierarchy will help reviewers.
     2. If you fix some SQL features, you can provide some references of other DBMSes.
     3. If there is design documentation, please add the link.
     4. If there is a discussion in the mailing list, please add the link.
   -->
   
   
   ### Why are the changes needed?
   <!--
   Please clarify why the changes are needed. For instance,
     1. If you propose a new API, clarify the use case for a new API.
     2. If you fix a bug, you can clarify why it is a bug.
   -->
   
   
   ### Does this PR introduce _any_ user-facing change?
   <!--
   Note that it means *any* user-facing change including all aspects such as the documentation fix.
   If yes, please clarify the previous behavior and the change this PR proposes - provide the console output, description and/or an example to show the behavior difference if possible.
   If possible, please also clarify if this is a user-facing change compared to the released Spark versions or within the unreleased branches such as master.
   If no, write 'No'.
   -->
   
   
   ### How was this patch tested?
   <!--
   If tests were added, say they were added here. Please make sure to add some test cases that check the changes thoroughly including negative and positive cases if possible.
   If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future.
   If tests were not added, please describe why they were not added and/or why it was difficult to add.
   If benchmark tests were added, please run the benchmarks in GitHub Actions for the consistent environment, and the instructions could accord to: https://spark.apache.org/developer-tools.html#github-workflow-benchmarks.
   -->
   




[GitHub] [spark] HyukjinKwon commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018681767


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()

Review Comment:
   Is this waiting for the partitions to be fetched in order? I think we can just fetch them all and send the first one as soon as it arrives. To optimize this, I think we should eventually do the reordering in a way that matches PySpark's implementation. Ideally we should also deduplicate the code.





[GitHub] [spark] amaliujia commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
amaliujia commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018699060


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()

Review Comment:
   Sorry, I am still a bit confused about how this maintains the ordering of partitions. I assume other people reading this code might be confused as well.
   
   Is it possible to add a developer comment here explaining the algorithm?
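   
   For reference, a minimal, self-contained sketch of the hand-off that the quoted
   snippet implements (illustrative only; the names merely mirror the diff above).
   The result handler fills per-partition slots in whatever order tasks finish,
   while the sending thread consumes slot 0, 1, 2, ... in order, blocking until
   the slot it needs has been filled:

      object OrderedHandOff {
        def main(args: Array[String]): Unit = {
          val numPartitions = 4
          val signal = new Object
          val partitions = Array.fill[Array[Byte]](numPartitions)(null)

          // Producer: results may arrive in any order (reversed here on purpose).
          new Thread(() => {
            (numPartitions - 1 to 0 by -1).foreach { partitionId =>
              Thread.sleep(10)
              signal.synchronized {
                partitions(partitionId) = Array.fill(4)(partitionId.toByte)
                signal.notify()
              }
            }
          }).start()

          // Consumer: emit partition 0, 1, 2, ... waiting for each slot to be filled.
          var currentPartitionId = 0
          while (currentPartitionId < numPartitions) {
            val partition = signal.synchronized {
              while (partitions(currentPartitionId) == null) signal.wait()
              val p = partitions(currentPartitionId)
              partitions(currentPartitionId) = null // free the slot once consumed
              p
            }
            println(s"sending partition $currentPartitionId (${partition.length} bytes)")
            currentPartitionId += 1
          }
        }
      }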





[GitHub] [spark] grundprinzip commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
grundprinzip commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018699081


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records

Review Comment:
   I don't think there is a throughput limit in gRPC itself.
   
   The reason for the batching is that protobuf is not well suited to large payloads: embedding large binary objects may force the reader to materialize them fully in memory.
   
   Fixing this is an optimization for later.
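   
   As an aside, here is a generic sketch of what the "control the batch size
   instead of max records" TODO in the quoted code could look like (illustrative
   only; the threshold name and value are made up, and a real implementation
   would estimate the serialized Arrow payload rather than raw element sizes):
   cut a batch when the accumulated size crosses a byte threshold instead of a
   fixed row count, so no single protobuf message becomes too large.

      object SizeCappedBatching {
        // Group an iterator of byte payloads into batches of roughly maxBytesPerBatch.
        def batched(rows: Iterator[Array[Byte]],
                    maxBytesPerBatch: Long): Iterator[Seq[Array[Byte]]] =
          new Iterator[Seq[Array[Byte]]] {
            override def hasNext: Boolean = rows.hasNext
            override def next(): Seq[Array[Byte]] = {
              val buf = Seq.newBuilder[Array[Byte]]
              var bytes = 0L
              // Always take at least one row so the iterator makes progress.
              while (rows.hasNext && (bytes == 0L || bytes < maxBytesPerBatch)) {
                val row = rows.next()
                buf += row
                bytes += row.length
              }
              buf.result()
            }
          }

        def main(args: Array[String]): Unit = {
          val rows = Iterator.fill(10)(Array.fill(300 * 1024)(0: Byte)) // ten ~300 KiB rows
          val sizes = batched(rows, maxBytesPerBatch = 1L * 1024 * 1024).map(_.map(_.length).sum)
          sizes.foreach(s => println(s"batch of $s bytes")) // each batch lands near the 1 MiB cap
        }
      }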





[GitHub] [spark] amaliujia commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
amaliujia commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018700003


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()
+            }
+            val partition = partitions(currentPartitionId)
+            partitions(currentPartitionId) = null
+            partition
+          }
+
+          // only send non-empty partitions
+          if (partition.nonEmpty && partition.exists(_._1.nonEmpty)) {
+            partition.foreach { case (bytes, count, size) =>
+              val response = proto.Response.newBuilder().setClientId(clientId)
+              val batch = proto.Response.ArrowBatch
+                .newBuilder()
+                .setRowCount(count)
+                .setUncompressedBytes(size)
+                .setCompressedBytes(bytes.length)
+                .setData(ByteString.copyFrom(bytes))
+                .build()
+              response.setArrowBatch(batch)
+              responseObserver.onNext(response.build())
+            }
+            numSent += 1
+          }
+
+          currentPartitionId += 1
+        }
+      }
+
+      // make sure at least 1 batch will be sent
+      if (numSent == 0) {

Review Comment:
   +1.
   
   With this we can at least get rid of the `Optional[Pandas]` from the API interface.





[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018690469


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit

Review Comment:
   got it, will update





[GitHub] [spark] cloud-fan commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
cloud-fan commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019152492


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##########
@@ -128,6 +128,92 @@ private[sql] object ArrowConverters extends Logging {
     }
   }
 
+  private[sql] def toArrowBatchIterator(

Review Comment:
   I see. We probably need better naming, as it's hard to tell the difference between `toBatchIterator` and `toArrowBatchIterator`.





[GitHub] [spark] amaliujia commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
amaliujia commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1015821355


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val pool = ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")
+      val tasks = collection.mutable.ArrayBuffer.empty[Future[_]]
+      val rows = dataframe.queryExecution.executedPlan.execute()
+
+      if (rows.getNumPartitions > 0) {
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], Long, Long)]) => {
+          if (taskResult.exists(_._1.nonEmpty)) {
+            // only send non-empty partitions
+            val task = pool.submit(new Runnable {
+              override def run(): Unit = {
+                var batchId = partitionId.toLong << 33
+                taskResult.foreach { case (bytes, count, size) =>
+                  val response = proto.Response.newBuilder().setClientId(clientId)
+                  val batch = proto.Response.ArrowBatch
+                    .newBuilder()
+                    .setBatchId(batchId)
+                    .setRowCount(count)
+                    .setUncompressedBytes(size)
+                    .setCompressedBytes(bytes.length)
+                    .setData(ByteString.copyFrom(bytes))
+                    .build()
+                  response.setArrowBatch(batch)
+                  responseObserver.onNext(response.build())

Review Comment:
   I feel like this touches a bigger problem: we need to define the client-side incremental collect protocol first. The client consumes data, but it does not consume it all at once. How does the client collect incrementally? And, combined with that, how does the server side control the rate?





[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1017736830


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,10 +127,99 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val queue = collection.mutable.Queue.empty[(Int, Array[(Array[Byte], Long, Long)])]

Review Comment:
   I don't think we should make the client responsible for ordering the results. This will be a burden for all clients.





[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013700526


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##########
@@ -128,6 +128,65 @@ private[sql] object ArrowConverters extends Logging {
     }
   }
 
+  private[sql] def toArrowBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Int,
+      timeZoneId: String,
+      context: TaskContext): Iterator[(Array[Byte], Long, Long)] = {
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+      "toArrowBatchIterator", 0, Long.MaxValue)
+
+    val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    val unloader = new VectorUnloader(root)
+    val arrowWriter = ArrowWriter.create(root)
+
+    if (context != null) { // for test at driver
+      context.addTaskCompletionListener[Unit] { _ =>
+        root.close()
+        allocator.close()
+      }
+    }
+
+    new Iterator[(Array[Byte], Long, Long)] {
+
+      override def hasNext: Boolean = rowIter.hasNext || {
+        root.close()
+        allocator.close()
+        false
+      }
+
+      override def next(): (Array[Byte], Long, Long) = {

Review Comment:
   ok, let me check it





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013555606


##########
python/pyspark/sql/connect/client.py:
##########
@@ -182,6 +191,10 @@ def _to_pandas(self, plan: pb2.Plan) -> Optional[pandas.DataFrame]:
         req = pb2.Request()
         req.user_context.user_id = self._user_id
         req.plan.CopyFrom(plan)
+        if self.has_arrow:
+            req.preferred_result_type = pb2.Request.ArrowBatch
+        else:
+            req.preferred_result_type = pb2.Request.JSONBatch

Review Comment:
   We could remove this and just let it throw an exception. I think we can make Arrow a hard requirement. Since this is new code that has not been released yet, let's go with the stricter approach. cc @grundprinzip FYI.
   
   We can remove this in a separate PR too.





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018688345


##########
python/pyspark/sql/tests/connect/test_connect_basic.py:
##########
@@ -197,6 +197,17 @@ def test_range(self):
             .equals(self.spark.range(start=0, end=10, step=3, numPartitions=2).toPandas())
         )
 
+    def test_empty_dataset(self):
+        self.assertTrue(

Review Comment:
   Let's add a comment with a JIRA reference like `# SPARK-XXXX: ...` (https://spark.apache.org/contributing.html)





[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019036963


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val pool = ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")
+      val tasks = collection.mutable.ArrayBuffer.empty[Future[_]]
+      val rows = dataframe.queryExecution.executedPlan.execute()
+
+      if (rows.getNumPartitions > 0) {
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], Long, Long)]) => {
+          if (taskResult.exists(_._1.nonEmpty)) {
+            // only send non-empty partitions
+            val task = pool.submit(new Runnable {
+              override def run(): Unit = {
+                var batchId = partitionId.toLong << 33
+                taskResult.foreach { case (bytes, count, size) =>
+                  val response = proto.Response.newBuilder().setClientId(clientId)
+                  val batch = proto.Response.ArrowBatch
+                    .newBuilder()
+                    .setBatchId(batchId)
+                    .setRowCount(count)
+                    .setUncompressedBytes(size)
+                    .setCompressedBytes(bytes.length)
+                    .setData(ByteString.copyFrom(bytes))
+                    .build()
+                  response.setArrowBatch(batch)
+                  responseObserver.onNext(response.build())

Review Comment:
   You could do that, yes. If we get the format right, we could just serve the block directly to the client (to avoid deserialization).
   
   Alternatively, you could write the data to some persistent storage on the executors and just pass the file paths to the client. That would be the most efficient from the driver's point of view, but it requires some garbage collection to be in place to clean up results.





[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019099531


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()

Review Comment:
   > If the first partition arrives last, the whole dataset stays in the driver's memory, right?
   
   Yes, but at least it's no worse than the existing `collect`, which always keeps the whole dataset in memory.
   
   Receiving the partitions in order may make them easier to consume on the client, if ordering matters.
   
   I think we will optimize it further; this is just an initial implementation.





[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019093054


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +120,93 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = collection.mutable.Map.empty[Int, Array[Batch]]

Review Comment:
   This just changes from an array to a map ... see https://github.com/apache/spark/pull/38468#discussion_r1018938395





[GitHub] [spark] grundprinzip commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
grundprinzip commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018799227


##########
connector/connect/src/main/protobuf/spark/connect/base.proto:
##########
@@ -83,7 +83,6 @@ message Response {
     int64 uncompressed_bytes = 2;

Review Comment:
   I think a premature optimization in my head wanted to put some streaming compression on it. We can probably just get rid of these fields for now.





[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1016363328


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val pool = ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")
+      val tasks = collection.mutable.ArrayBuffer.empty[Future[_]]
+      val rows = dataframe.queryExecution.executedPlan.execute()
+
+      if (rows.getNumPartitions > 0) {
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], Long, Long)]) => {

Review Comment:
   I guess the reordering on the client side is cheap: it only needs to sort the batch ids, one integer per batch.
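   
   A tiny illustration of that point, assuming each batch carries an id encoded
   as `partitionId.toLong << 33` plus an offset as in the quoted code (names
   below are hypothetical): restoring the global order on the client is just a
   sort over one long per batch, independent of the payload size.

      object ReorderByBatchId {
        final case class ReceivedBatch(batchId: Long, payload: Array[Byte])

        // Hypothetical client-side step: sort received batches by their id.
        def restoreOrder(received: Seq[ReceivedBatch]): Seq[ReceivedBatch] =
          received.sortBy(_.batchId)

        def main(args: Array[String]): Unit = {
          def id(partitionId: Int, offset: Int): Long = (partitionId.toLong << 33) + offset
          // Batches arriving out of order (partition 1 before partition 0).
          val received = Seq(
            ReceivedBatch(id(1, 0), Array[Byte](3)),
            ReceivedBatch(id(0, 1), Array[Byte](2)),
            ReceivedBatch(id(0, 0), Array[Byte](1)))
          restoreOrder(received).foreach(b => println(b.batchId))
        }
      }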





[GitHub] [spark] pan3793 commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
pan3793 commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1015733832


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val pool = ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")
+      val tasks = collection.mutable.ArrayBuffer.empty[Future[_]]
+      val rows = dataframe.queryExecution.executedPlan.execute()
+
+      if (rows.getNumPartitions > 0) {
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], Long, Long)]) => {
+          if (taskResult.exists(_._1.nonEmpty)) {
+            // only send non-empty partitions
+            val task = pool.submit(new Runnable {
+              override def run(): Unit = {
+                var batchId = partitionId.toLong << 33
+                taskResult.foreach { case (bytes, count, size) =>
+                  val response = proto.Response.newBuilder().setClientId(clientId)
+                  val batch = proto.Response.ArrowBatch
+                    .newBuilder()
+                    .setBatchId(batchId)
+                    .setRowCount(count)
+                    .setUncompressedBytes(size)
+                    .setCompressedBytes(bytes.length)
+                    .setData(ByteString.copyFrom(bytes))
+                    .build()
+                  response.setArrowBatch(batch)
+                  responseObserver.onNext(response.build())

Review Comment:
   If I understand correctly, the `TaskResultGetter` runs in a separate thread pool and there is no back-pressure mechanism, so the current approach may still cause memory pressure on the driver if the client consumes results slowly.
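   
   For illustration only, one generic way to bound that memory is a bounded queue
   between whatever produces serialized batches and the thread writing to the
   stream, so a slow consumer eventually blocks the producer. This is just a
   sketch of the bounded-buffer idea, not a proposal for where to put it in this
   PR; blocking the scheduler-side result handler itself would have its own
   problems.

      import java.util.concurrent.ArrayBlockingQueue

      object BoundedHandOff {
        def main(args: Array[String]): Unit = {
          val maxBufferedBatches = 8 // hypothetical cap on batches held in memory
          val queue = new ArrayBlockingQueue[Array[Byte]](maxBufferedBatches)

          val producer = new Thread(() => {
            (1 to 100).foreach { i =>
              queue.put(Array.fill(1024)(i.toByte)) // blocks once 8 batches are buffered
            }
            queue.put(Array.emptyByteArray) // empty array as an end-of-stream marker
          })

          val consumer = new Thread(() => {
            var batch = queue.take()
            while (batch.nonEmpty) {
              Thread.sleep(5) // simulate a slow client
              batch = queue.take()
            }
          })

          producer.start(); consumer.start()
          producer.join(); consumer.join()
        }
      }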





[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1015387801


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val pool = ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")

Review Comment:
   Why not use the main thread for this? It is not entirely cheap to spin up a thread pool, and you have the main thread twiddling its thumbs anyway. It also makes completing the stream simpler.





[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1014280302


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##########
@@ -128,6 +128,65 @@ private[sql] object ArrowConverters extends Logging {
     }
   }
 
+  private[sql] def toArrowBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Int,
+      timeZoneId: String,
+      context: TaskContext): Iterator[(Array[Byte], Long, Long)] = {
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+      "toArrowBatchIterator", 0, Long.MaxValue)
+
+    val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    val unloader = new VectorUnloader(root)
+    val arrowWriter = ArrowWriter.create(root)
+
+    if (context != null) { // for test at driver
+      context.addTaskCompletionListener[Unit] { _ =>
+        root.close()
+        allocator.close()
+      }
+    }
+
+    new Iterator[(Array[Byte], Long, Long)] {
+
+      override def hasNext: Boolean = rowIter.hasNext || {
+        root.close()
+        allocator.close()
+        false
+      }
+
+      override def next(): (Array[Byte], Long, Long) = {
+        val out = new ByteArrayOutputStream()
+        val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+        var rowCount = 0L
+        var estimatedSize = 0L
+        Utils.tryWithSafeFinally {
+          while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+            val row = rowIter.next()
+            arrowWriter.write(row)
+            rowCount += 1
+            estimatedSize += SizeEstimator.estimate(row)
+          }
+          arrowWriter.finish()
+          val batch = unloader.getRecordBatch()
+
+          MessageSerializer.serialize(writeChannel, arrowSchema)

Review Comment:
   sgtm





[GitHub] [spark] amaliujia commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
amaliujia commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018699466


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()
+            }
+            val partition = partitions(currentPartitionId)
+            partitions(currentPartitionId) = null
+            partition
+          }
+
+          // only send non-empty partitions
+          if (partition.nonEmpty && partition.exists(_._1.nonEmpty)) {

Review Comment:
   Just for my education: is an empty partition meaningless for clients? Even if it carries no data, knowing that a range of the dataset is empty might still be useful, etc.





[GitHub] [spark] amaliujia commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
amaliujia commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018697387


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records

Review Comment:
   Interesting. I only knew that gRPC has a hard 2 GB/s transfer-rate limit; I never knew it might not handle large messages well.





[GitHub] [spark] pan3793 commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
pan3793 commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019048009


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val pool = ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")
+      val tasks = collection.mutable.ArrayBuffer.empty[Future[_]]
+      val rows = dataframe.queryExecution.executedPlan.execute()
+
+      if (rows.getNumPartitions > 0) {
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], Long, Long)]) => {
+          if (taskResult.exists(_._1.nonEmpty)) {
+            // only send non-empty partitions
+            val task = pool.submit(new Runnable {
+              override def run(): Unit = {
+                var batchId = partitionId.toLong << 33
+                taskResult.foreach { case (bytes, count, size) =>
+                  val response = proto.Response.newBuilder().setClientId(clientId)
+                  val batch = proto.Response.ArrowBatch
+                    .newBuilder()
+                    .setBatchId(batchId)
+                    .setRowCount(count)
+                    .setUncompressedBytes(size)
+                    .setCompressedBytes(bytes.length)
+                    .setData(ByteString.copyFrom(bytes))
+                    .build()
+                  response.setArrowBatch(batch)
+                  responseObserver.onNext(response.build())

Review Comment:
   To be clear, my suggestion is to defer the deserialization and block releasing to the `JobWaiter#taskSucceeded` phase.
   
   The second way you suggested has many performance benefits, but it requires the client to communicate with an external storage service, which brings other burdens for the client, e.g. result cleanup (as you mentioned), the S3/HDFS client libs being quite large, network policy, and auth.
   
   The `IndirectTaskResult` way leverages the existing code and Spark's built-in block mechanism to transfer data; we can benefit with only a little code modification, and we don't need to worry about result cleanup.





[GitHub] [spark] cloud-fan commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
cloud-fan commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019090389


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +120,93 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = collection.mutable.Map.empty[Int, Array[Batch]]

Review Comment:
   do we really need a map? We know the number of partitions and we can just create an array.





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019820867


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +120,93 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = collection.mutable.Map.empty[Int, Array[Batch]]
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        // This callback is executed by the DAGScheduler thread.
+        // After fetching a partition, it inserts the partition into the Map, and then
+        // wakes up the main thread.
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          ()
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)

Review Comment:
   Seems like it should be the (async) `submitJob` instead of the (sync) `runJob` (https://github.com/apache/spark/pull/38468#discussion_r1013184548). In fact, I figured out a simpler way to avoid the synchronization. PTAL at https://github.com/apache/spark/pull/38613
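   
   For reference, a minimal sketch of the async variant being suggested, assuming an `RDD[Array[Byte]]` of serialized batches and a result handler like the one in the diff; `runJob` blocks until every partition has been handled, while `submitJob` returns immediately so the caller can stream completed partitions while later ones are still running:
   
   ```scala
   import org.apache.spark.SparkContext
   import org.apache.spark.rdd.RDD
   
   // Sketch only: asynchronous submission of the Arrow-batch job.
   def submitBatches(
       sc: SparkContext,
       batches: RDD[Array[Byte]],
       resultHandler: (Int, Array[Array[Byte]]) => Unit): Unit = {
     val future = sc.submitJob(
       rdd = batches,
       processPartition = (it: Iterator[Array[Byte]]) => it.toArray,
       partitions = 0 until batches.getNumPartitions,
       resultHandler = resultHandler,
       resultFunc = ())
     // `future` completes once every partition has been passed to resultHandler.
   }
   ```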





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018837877


##########
connector/connect/src/main/protobuf/spark/connect/base.proto:
##########
@@ -83,7 +83,6 @@ message Response {
     int64 uncompressed_bytes = 2;

Review Comment:
   @zhengruifeng let's get rid of this; then we don't need to estimate the size, and it would be easier to deduplicate the code.





[GitHub] [spark] pan3793 commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
pan3793 commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019058105


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()

Review Comment:
   In the reduce phase, the task fetches the map data of each partition in random order; without a local sort, the user still sees indeterminate data even if the driver returns data by partition id.





[GitHub] [spark] cloud-fan commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
cloud-fan commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019093136


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +120,93 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = collection.mutable.Map.empty[Int, Array[Batch]]

Review Comment:
   We should probably add some comments at the beginning to explain the overall workflow.





[GitHub] [spark] cloud-fan commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
cloud-fan commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019094064


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +120,93 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = collection.mutable.Map.empty[Int, Array[Batch]]

Review Comment:
   can we apply the same idea to JSON batches?





[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019032161


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()

Review Comment:
   That is an optimization we can go for after merging this. My only problem with it is that it might scare users, because collecting a dataframe multiple times might produce results that look different.





[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013184548


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,7 +131,36 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
+    responseObserver.onCompleted()
+  }
+
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val schema = dataframe.schema
+    val maxRecordsPerBatch = dataframe.sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = dataframe.sparkSession.sessionState.conf.sessionLocalTimeZone
+
+    val batches = dataframe.queryExecution.executedPlan
+      .execute()
+      .mapPartitionsInternal { iter =>
+        ArrowConverters
+          .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId, TaskContext.get)
+      }
+
+    batches.toLocalIterator.foreach { case (bytes, count, size) =>

Review Comment:
   Let's also avoid using a local iterator. That will make things unnecessarily slow because it limits parallelism. I was thinking of doing the following:
   ```scala
   def collectArrow(callback: Array[Array[Byte]] => Unit): Unit = {
       withAction("collectArrow", queryExecution) { plan =>
         val rdd = toArrowBatchRdd(plan)
         val signal = new Object
         val numPartitions = rdd.getNumPartitions
         val availablePartitions = scala.collection.mutable.Map.empty[Int, Array[Array[Byte]]]
         def onNewPartition(partitionId: Int, batches: Array[Array[Byte]]): Unit = {
           signal.synchronized {
             availablePartitions(partitionId) = batches
             signal.notify()
           }
         }
         sparkSession.sparkContext.submitJob(
           rdd = rdd,
           processPartition = (i: Iterator[Array[Byte]]) => i.toArray,
           partitions = 0 until numPartitions,
           resultHandler = onNewPartition,
           resultFunc = () => ())
   
         var currentPartitionId = 0
         while (currentPartitionId < numPartitions) {
           val batches = signal.synchronized {
             var result = availablePartitions.remove(currentPartitionId)
             while (result.isEmpty) {
               signal.wait()
               result = availablePartitions.remove(currentPartitionId)
             }
             result.get
           }
           callback(batches)
           currentPartitionId += 1
         }
       }
     }
   ```
   This was hacked together in Dataset, but IMO it is better to do this somewhere in here. 
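   
   A hypothetical call site for the sketch above, only to show the intended flow; the response building is omitted:
   
   ```scala
   // Hypothetical usage of the collectArrow sketch above: the callback receives
   // one partition's batches at a time, already in partition order.
   collectArrow { batches =>
     batches.foreach { bytes =>
       // wrap `bytes` in a proto ArrowBatch and push it to responseObserver here
     }
   }
   ```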





[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1014279677


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,7 +126,70 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
+    responseObserver.onCompleted()
+  }
+
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    val rows = dataframe.queryExecution.executedPlan.execute()
+    var numBatches = 0L
+
+    if (rows.getNumPartitions > 0) {
+      val batches = rows.mapPartitionsInternal { iter =>
+        ArrowConverters
+          .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+      }
+
+      val obj = new Object
+
+      val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray
+
+      val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], Long, Long)]) =>
+        obj.synchronized {
+          var batchId = partitionId.toLong << 33
+          taskResult.foreach { case (bytes, count, size) =>
+            val response = proto.Response.newBuilder().setClientId(clientId)
+            val batch = proto.Response.ArrowBatch
+              .newBuilder()
+              .setBatchId(batchId)
+              .setRowCount(count)
+              .setUncompressedBytes(size)
+              .setCompressedBytes(bytes.length)
+              .setData(ByteString.copyFrom(bytes))
+              .build()
+            response.setArrowBatch(batch)
+            responseObserver.onNext(response.build())

Review Comment:
   This callback is currently executed by the DAGScheduler thread, and if it is expensive, no job/stage can be scheduled during this call. We really should move this to a separate thread.
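   
   A minimal sketch of that pattern (not the PR code): the scheduler callback only hands the partition off, and a dedicated thread does the serialization and `onNext` calls. The names and tuple layout mirror the diff but are stand-ins:
   
   ```scala
   import java.util.concurrent.Executors
   
   object OffloadSketch {
     type Batch = (Array[Byte], Long, Long)
     private val sender = Executors.newSingleThreadExecutor()
   
     // Called on the DAGScheduler thread: enqueue and return immediately.
     def resultHandler(send: Batch => Unit)(partitionId: Int, taskResult: Array[Batch]): Unit = {
       sender.submit(new Runnable {
         // Runs on the sender thread, keeping the scheduler thread free.
         override def run(): Unit = taskResult.foreach(send)
       })
       ()
     }
   }
   ```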





[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1014269647


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,7 +126,70 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
+    responseObserver.onCompleted()
+  }
+
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    val rows = dataframe.queryExecution.executedPlan.execute()
+    var numBatches = 0L
+
+    if (rows.getNumPartitions > 0) {
+      val batches = rows.mapPartitionsInternal { iter =>
+        ArrowConverters
+          .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+      }
+
+      val obj = new Object
+
+      val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray

Review Comment:
   This breaks sorted results. A higher partition can complete earlier than lower ones, thus breaking the order. That is why the snippet I posted buffered the partitions in the handler, while the main thread scanned over them one by one.





[GitHub] [spark] amaliujia commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
amaliujia commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1015821790


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val pool = ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")
+      val tasks = collection.mutable.ArrayBuffer.empty[Future[_]]
+      val rows = dataframe.queryExecution.executedPlan.execute()
+
+      if (rows.getNumPartitions > 0) {
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], Long, Long)]) => {

Review Comment:
   Another question: do we always need to maintain the partition ordering?





[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1016363328


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val pool = ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")
+      val tasks = collection.mutable.ArrayBuffer.empty[Future[_]]
+      val rows = dataframe.queryExecution.executedPlan.execute()
+
+      if (rows.getNumPartitions > 0) {
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], Long, Long)]) => {

Review Comment:
   I guess the reordering on the client side is cheap; it only sorts as many ints as there are batches.
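   
   Illustration only: reordering touches one long id per batch, never the Arrow payloads themselves. The type below is made up for the sketch:
   
   ```scala
   // Made-up client-side type: reorder received batches by id before use.
   final case class ReceivedBatch(batchId: Long, data: Array[Byte])
   
   def reorder(batches: Seq[ReceivedBatch]): Seq[Array[Byte]] =
     batches.sortBy(_.batchId).map(_.data)
   ```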





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018689309


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit

Review Comment:
   You can specify `Unit` with just `()`.
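   
   A tiny sketch of the suggestion; `signal` and `partitions` are stand-ins for the ones in the diff:
   
   ```scala
   object UnitReturnSketch {
     type Batch = (Array[Byte], Long, Long)
     private val signal = new Object
     private val partitions = new Array[Array[Batch]](4) // stand-in size
   
     val resultHandler: (Int, Array[Batch]) => Unit = (partitionId, partition) => {
       signal.synchronized {
         partitions(partitionId) = partition
         signal.notify()
       }
       () // the last expression is Unit; no dummy `val i = 0` needed
     }
   }
   ```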





[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [WIP][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1012441577


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,7 +121,38 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
+    responseObserver.onCompleted()
+  }
+
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer}
+
+    val schema = dataframe.schema
+    val maxRecordsPerBatch = dataframe.sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = dataframe.sparkSession.sessionState.conf.sessionLocalTimeZone
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    val schemaBuffer = MessageSerializer.serializeMetadata(arrowSchema, IpcOption.DEFAULT)
+
+    val rows = dataframe.queryExecution.executedPlan.execute().map(_.copy()).collect()

Review Comment:
   1. Agree, will update.
   2. I think we can use `toLocalIterator`, but it may trigger multiple jobs (see the sketch below): https://github.com/apache/spark/pull/38300#discussion_r1001301224
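   
   A quick, self-contained sketch of the difference, with placeholder data:
   
   ```scala
   import org.apache.spark.sql.SparkSession
   
   // Illustration only: `collect` materializes everything in one job, while
   // `toLocalIterator` fetches partitions lazily and may run a job per partition.
   object LocalIteratorSketch {
     def main(args: Array[String]): Unit = {
       val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
       val rdd = spark.sparkContext.parallelize(1 to 100, numSlices = 4)
   
       val all = rdd.collect()        // one job, whole result held on the driver
       val it = rdd.toLocalIterator   // lazy; later partitions fetched as iteration proceeds
       println(s"${all.length} collected; first via iterator: ${it.next()}")
   
       spark.stop()
     }
   }
   ```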





[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1016358065


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val pool = ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")

Review Comment:
   > Why not use the main thread for this?
   
   ok will follow https://github.com/apache/spark/pull/38468#discussion_r1013184548








[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1014600233


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val pool = ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")
+      val tasks = collection.mutable.ArrayBuffer.empty[Future[_]]
+      val rows = dataframe.queryExecution.executedPlan.execute()
+
+      if (rows.getNumPartitions > 0) {
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], Long, Long)]) => {
+          if (taskResult.exists(_._1.nonEmpty)) {
+            // only send non-empty partitions
+            val task = pool.submit(new Runnable {
+              override def run(): Unit = {
+                var batchId = partitionId.toLong << 33

Review Comment:
   Generate batch ids in the same way as `monotonically_increasing_id`.
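   
   Sketch of that id scheme, mirroring `monotonically_increasing_id`: the partition id occupies the upper bits and the per-partition batch index the lower 33 bits:
   
   ```scala
   // Illustrative helper, not from the PR.
   def batchIds(partitionId: Int, numBatches: Int): Seq[Long] = {
     val base = partitionId.toLong << 33
     (0 until numBatches).map(i => base + i)
   }
   
   // e.g. batchIds(2, 3) == Seq(17179869184L, 17179869185L, 17179869186L)
   ```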





[GitHub] [spark] amaliujia commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
amaliujia commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018700467


##########
python/pyspark/sql/tests/connect/test_connect_basic.py:
##########
@@ -197,6 +197,17 @@ def test_range(self):
             .equals(self.spark.range(start=0, end=10, step=3, numPartitions=2).toPandas())
         )
 
+    def test_empty_dataset(self):
+        self.assertTrue(

Review Comment:
   +1. Let's maintain the contribution style in this module.





[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018794542


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()
+            }
+            val partition = partitions(currentPartitionId)
+            partitions(currentPartitionId) = null
+            partition
+          }
+
+          // only send non-empty partitions
+          if (partition.nonEmpty && partition.exists(_._1.nonEmpty)) {
+            partition.foreach { case (bytes, count, size) =>
+              val response = proto.Response.newBuilder().setClientId(clientId)
+              val batch = proto.Response.ArrowBatch
+                .newBuilder()
+                .setRowCount(count)
+                .setUncompressedBytes(size)
+                .setCompressedBytes(bytes.length)
+                .setData(ByteString.copyFrom(bytes))
+                .build()
+              response.setArrowBatch(batch)
+              responseObserver.onNext(response.build())
+            }
+            numSent += 1
+          }
+
+          currentPartitionId += 1
+        }
+      }
+
+      // make sure at least 1 batch will be sent
+      if (numSent == 0) {

Review Comment:
   `Optional[Pandas]` may still be needed, since 0 JSON batches may be returned.





[GitHub] [spark] grundprinzip commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
grundprinzip commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018805569


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()
+            }
+            val partition = partitions(currentPartitionId)
+            partitions(currentPartitionId) = null
+            partition
+          }
+
+          // only send non-empty partitions

Review Comment:
   ```suggestion
             // Only send non-empty partitions.
   ```





[GitHub] [spark] grundprinzip commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
grundprinzip commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018810515


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()
+            }
+            val partition = partitions(currentPartitionId)
+            partitions(currentPartitionId) = null
+            partition
+          }
+
+          // only send non-empty partitions
+          if (partition.nonEmpty && partition.exists(_._1.nonEmpty)) {
+            partition.foreach { case (bytes, count, size) =>
+              val response = proto.Response.newBuilder().setClientId(clientId)
+              val batch = proto.Response.ArrowBatch
+                .newBuilder()
+                .setRowCount(count)
+                .setUncompressedBytes(size)
+                .setCompressedBytes(bytes.length)
+                .setData(ByteString.copyFrom(bytes))
+                .build()
+              response.setArrowBatch(batch)
+              responseObserver.onNext(response.build())
+            }
+            numSent += 1

Review Comment:
   Is there a way we could track this as a Spark metric for the query? Fine to do in a follow-up if we create a JIRA.





[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018938395


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)

Review Comment:
   You could use a map here. That will be a bit friendlier on memory when you have a job with many tasks. The insight is that it is unlikely that you have to buffer much, since Spark executes partitions in order (from low to high).
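   A sketch of that suggestion, adapting the diff above (`Batch`, `batches`, and `numPartitions` are the names from that diff; the PR later switches to exactly this kind of `mutable.Map`, as quoted further down in this thread):

   ```scala
   // Buffer finished partitions by id; only out-of-order stragglers stay in memory.
   val signal = new Object
   val partitions = scala.collection.mutable.Map.empty[Int, Array[Batch]]

   val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
     signal.synchronized {
       partitions(partitionId) = partition
       signal.notify()
     }
   }

   spark.sparkContext.runJob(batches, (iter: Iterator[Batch]) => iter.toArray, resultHandler)

   var currentPartitionId = 0
   while (currentPartitionId < numPartitions) {
     val partition = signal.synchronized {
       while (!partitions.contains(currentPartitionId)) {
         signal.wait()
       }
       partitions.remove(currentPartitionId).get // drop it as soon as it is consumed
     }
     // ... send the batches of `partition` ...
     currentPartitionId += 1
   }
   ```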





[GitHub] [spark] pan3793 commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
pan3793 commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019007963


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()

Review Comment:
   And another question: if there is no `Sort`, do we need to keep the partition order?





[GitHub] [spark] HyukjinKwon commented on pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on PR #38468:
URL: https://github.com/apache/spark/pull/38468#issuecomment-1311310134

   Made another PR to refactor and deduplicate the Arrow code, PTAL: https://github.com/apache/spark/pull/38618




[GitHub] [spark] zhengruifeng commented on pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on PR #38468:
URL: https://github.com/apache/spark/pull/38468#issuecomment-1311113900

   Merged into master, will have a follow-up PR to update `toArrowBatchIterator`.




[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1015381686


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val pool = ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")
+      val tasks = collection.mutable.ArrayBuffer.empty[Future[_]]
+      val rows = dataframe.queryExecution.executedPlan.execute()
+
+      if (rows.getNumPartitions > 0) {
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], Long, Long)]) => {

Review Comment:
   I think we should discuss client-side ordering of results a bit more. My main problem with this is that it is unexpected behavior, undocumented behavior (hint hint), and it puts more burden on the client (it needs to reorder results and buffer them).
   
   cc @HyukjinKwon @amaliujia @grundprinzip 





[GitHub] [spark] pan3793 commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
pan3793 commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1016118342


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val pool = ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")
+      val tasks = collection.mutable.ArrayBuffer.empty[Future[_]]
+      val rows = dataframe.queryExecution.executedPlan.execute()
+
+      if (rows.getNumPartitions > 0) {
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], Long, Long)]) => {
+          if (taskResult.exists(_._1.nonEmpty)) {
+            // only send non-empty partitions
+            val task = pool.submit(new Runnable {
+              override def run(): Unit = {
+                var batchId = partitionId.toLong << 33
+                taskResult.foreach { case (bytes, count, size) =>
+                  val response = proto.Response.newBuilder().setClientId(clientId)
+                  val batch = proto.Response.ArrowBatch
+                    .newBuilder()
+                    .setBatchId(batchId)
+                    .setRowCount(count)
+                    .setUncompressedBytes(size)
+                    .setCompressedBytes(bytes.length)
+                    .setData(ByteString.copyFrom(bytes))
+                    .build()
+                  response.setArrowBatch(batch)
+                  responseObserver.onNext(response.build())

Review Comment:
   Things get worse if we decide to keep strict partition ordering.
   
   Actually, Spark implemented `IndirectTaskResult` to support transferring task results from executors to the driver through the block manager, but currently it eagerly fetches and deserializes blocks into `DirectTaskResult` on the `task-result-getter` thread pool. What if we defer that to `JobWaiter#taskSucceeded`?





[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013560372


##########
python/pyspark/sql/connect/client.py:
##########
@@ -182,6 +191,10 @@ def _to_pandas(self, plan: pb2.Plan) -> Optional[pandas.DataFrame]:
         req = pb2.Request()
         req.user_context.user_id = self._user_id
         req.plan.CopyFrom(plan)
+        if self.has_arrow:
+            req.preferred_result_type = pb2.Request.ArrowBatch
+        else:
+            req.preferred_result_type = pb2.Request.JSONBatch

Review Comment:
   I noticed that PySpark checks whether the schema is supported: https://github.com/apache/spark/blob/master/python/pyspark/sql/pandas/conversion.py#L102 ;
   I moved this check to the server side since it needs the schema.
   
   Maybe we can keep JSON as a fallback in case it supports more data types (not sure about this).
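   For reference, a minimal sketch of the server-side check described here (it mirrors the `handlePlan` diff quoted later in this thread):

   ```scala
   import scala.util.Try

   // Fall back to JSON batches whenever the schema cannot be mapped to Arrow.
   val arrowSupported = Try {
     ArrowUtils.toArrowSchema(dataframe.schema, session.sessionState.conf.sessionLocalTimeZone)
   }.isSuccess

   if (arrowSupported) {
     processRowsAsArrowBatches(request.getClientId, dataframe)
   } else {
     processRowsAsJsonBatches(request.getClientId, dataframe)
   }
   ```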





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013528536


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,7 +131,36 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
+    responseObserver.onCompleted()
+  }
+
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val schema = dataframe.schema
+    val maxRecordsPerBatch = dataframe.sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = dataframe.sparkSession.sessionState.conf.sessionLocalTimeZone
+
+    val batches = dataframe.queryExecution.executedPlan
+      .execute()
+      .mapPartitionsInternal { iter =>
+        ArrowConverters
+          .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId, TaskContext.get)
+      }
+
+    batches.toLocalIterator.foreach { case (bytes, count, size) =>

Review Comment:
   Yeah, we should avoid `toLocalIterator`.





[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013560372


##########
python/pyspark/sql/connect/client.py:
##########
@@ -182,6 +191,10 @@ def _to_pandas(self, plan: pb2.Plan) -> Optional[pandas.DataFrame]:
         req = pb2.Request()
         req.user_context.user_id = self._user_id
         req.plan.CopyFrom(plan)
+        if self.has_arrow:
+            req.preferred_result_type = pb2.Request.ArrowBatch
+        else:
+            req.preferred_result_type = pb2.Request.JSONBatch

Review Comment:
   I noticed that PySpark checks whether the schema is supported: https://github.com/apache/spark/blob/master/python/pyspark/sql/pandas/conversion.py#L102 ;
   I moved this check to the server side since it needs the schema.
   
   ~~Maybe we can keep JSON as a fallback in case it supports more data types (not sure about this).~~
   
   Yes, we always prefer Arrow, but the client may get JSON batches from the server if the schema is not supported.





[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [WIP][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1012742215


##########
python/pyspark/sql/connect/client.py:
##########
@@ -251,6 +263,13 @@ def _execute_and_fetch(self, req: pb2.Request) -> typing.Optional[pandas.DataFra
 
         if len(result_dfs) > 0:
             df = pd.concat(result_dfs)
+
+            # pd.concat generates non-consecutive index like:
+            #   Int64Index([0, 1, 0, 1, 2, 0, 1, 0, 1, 2], dtype='int64')
+            # set it to RangeIndex to be consistent with pyspark
+            n = len(df)
+            df = df.set_index(pd.RangeIndex(start=0, stop=n, step=1))

Review Comment:
   Made this change, otherwise some tests will fail.
   
   Those tests only generate a single JSON batch, so they work with JSON.



##########
python/pyspark/sql/connect/client.py:
##########
@@ -251,6 +263,13 @@ def _execute_and_fetch(self, req: pb2.Request) -> typing.Optional[pandas.DataFra
 
         if len(result_dfs) > 0:
             df = pd.concat(result_dfs)
+
+            # pd.concat generates non-consecutive index like:
+            #   Int64Index([0, 1, 0, 1, 2, 0, 1, 0, 1, 2], dtype='int64')
+            # set it to RangeIndex to be consistent with pyspark
+            n = len(df)
+            df = df.set_index(pd.RangeIndex(start=0, stop=n, step=1))

Review Comment:
   Made this change, otherwise some tests will fail.
   
   Those tests only generate a single JSON batch, so they worked with JSON.





[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [WIP][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1012428719


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,7 +121,38 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
+    responseObserver.onCompleted()
+  }
+
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer}
+
+    val schema = dataframe.schema
+    val maxRecordsPerBatch = dataframe.sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = dataframe.sparkSession.sessionState.conf.sessionLocalTimeZone
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    val schemaBuffer = MessageSerializer.serializeMetadata(arrowSchema, IpcOption.DEFAULT)
+
+    val rows = dataframe.queryExecution.executedPlan.execute().map(_.copy()).collect()

Review Comment:
   - Wouldn't it be better to do the heavy lifting on the executors? IMO it is better to convert to Arrow directly. `Dataset.toArrowBatchRdd` seems to be a good start.
   - It would also be nice if we could avoid materializing the entire result on the driver. We should be able to forward the batches for a partition immediately when we receive them.
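   A rough sketch of that shape, reusing the names from the PR (error handling and result ordering are omitted here; this is roughly where the PR ends up):

   ```scala
   // Convert to Arrow inside each task, then forward a partition's batches to the
   // gRPC stream as soon as its result arrives on the driver.
   val batches = dataframe.queryExecution.executedPlan.execute().mapPartitionsInternal { iter =>
     ArrowConverters.toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
   }

   spark.sparkContext.runJob(
     batches,
     (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray,
     (partitionId: Int, partition: Array[(Array[Byte], Long, Long)]) =>
       partition.foreach { case (bytes, _, _) =>
         // wrap `bytes` in a proto.Response.ArrowBatch and call responseObserver.onNext(...)
         ()
       })
   ```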





[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1016362177


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val pool = ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")
+      val tasks = collection.mutable.ArrayBuffer.empty[Future[_]]
+      val rows = dataframe.queryExecution.executedPlan.execute()
+
+      if (rows.getNumPartitions > 0) {
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], Long, Long)]) => {

Review Comment:
   I guess the reordering on the client side is cheap; it only sorts as many ints as there are batches.
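   A toy illustration of that cost (the batch id encodes the partition id in its high bits, `partitionId.toLong << 33` in the earlier diff, so restoring order is a sort over a handful of longs):

   ```scala
   // Batches arriving out of order; sorting by batch id restores partition order.
   val received = Seq((2L << 33, "partition-2"), (0L << 33, "partition-0"), (1L << 33, "partition-1"))
   val ordered = received.sortBy(_._1).map(_._2) // List(partition-0, partition-1, partition-2)
   ```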





[GitHub] [spark] cloud-fan commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
cloud-fan commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019092340


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +120,93 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = collection.mutable.Map.empty[Int, Array[Batch]]

Review Comment:
   Never mind, this is kind of an async collect.





[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018681437


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -48,19 +51,25 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
     }
   }
 
-  def handlePlan(session: SparkSession, request: proto.Request): Unit = {
+  def handlePlan(session: SparkSession, request: Request): Unit = {
     // Extract the plan from the request and convert it to a logical plan
     val planner = new SparkConnectPlanner(request.getPlan.getRoot, session)
-    val rows =
-      Dataset.ofRows(session, planner.transform())
-    processRows(request.getClientId, rows)
+    val dataframe = Dataset.ofRows(session, planner.transform())
+    // check whether all data types are supported
+    if (Try {
+        ArrowUtils.toArrowSchema(dataframe.schema, session.sessionState.conf.sessionLocalTimeZone)
+      }.isSuccess) {
+      processRowsAsArrowBatches(request.getClientId, dataframe)
+    } else {
+      processRowsAsJsonBatches(request.getClientId, dataframe)

Review Comment:
   Nice, will update.





[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018684111


##########
python/pyspark/sql/connect/client.py:
##########
@@ -400,6 +400,14 @@ def _execute_and_fetch(self, req: pb2.Request) -> typing.Optional[pandas.DataFra
 
         if len(result_dfs) > 0:
             df = pd.concat(result_dfs)
+            del result_dfs

Review Comment:
   Just want to release the buffer ASAP; maybe it's not needed.





[GitHub] [spark] pan3793 commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
pan3793 commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019007963


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()

Review Comment:
   And another question: if there is no `Sort`, do we need to preserve the partition order?





[GitHub] [spark] amaliujia commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
amaliujia commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1015819772


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val pool = ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")
+      val tasks = collection.mutable.ArrayBuffer.empty[Future[_]]
+      val rows = dataframe.queryExecution.executedPlan.execute()
+
+      if (rows.getNumPartitions > 0) {
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], Long, Long)]) => {

Review Comment:
   @cloud-fan and I had a discussion on this. I remember our initial thought was to maintain the partition ordering on the server side. For example, there is a `def foreachPartition` API already. The question was whether we can iterate over the partitions in partition order.





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018685534


##########
connector/connect/src/main/protobuf/spark/connect/base.proto:
##########
@@ -83,7 +83,6 @@ message Response {
     int64 uncompressed_bytes = 2;

Review Comment:
   @grundprinzip do you know why we need this?





[GitHub] [spark] amaliujia commented on pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
amaliujia commented on PR #38468:
URL: https://github.com/apache/spark/pull/38468#issuecomment-1309851369

   Thanks for the most recent updates. I find it also helps with other problems. For example, now we send at least one batch with the schema even when all partitions are empty. By doing so, clients won't need to worry about a `None` check.
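   For reference, a sketch of how such a schema-only batch can be built (this is essentially what the `createEmptyArrowBatch` helper referenced later in the thread boils down to; the function name here is assumed):

   ```scala
   import java.io.ByteArrayOutputStream
   import java.nio.channels.Channels
   import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader}
   import org.apache.arrow.vector.ipc.{ArrowStreamWriter, WriteChannel}
   import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer}
   import org.apache.spark.sql.execution.arrow.ArrowWriter
   import org.apache.spark.sql.types.StructType
   import org.apache.spark.sql.util.ArrowUtils

   // Serialize a zero-row Arrow IPC stream that still carries the schema.
   def emptyArrowBatch(schema: StructType, timeZoneId: String): Array[Byte] = {
     val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
     val allocator = ArrowUtils.rootAllocator.newChildAllocator("emptyArrowBatch", 0, Long.MaxValue)
     val root = VectorSchemaRoot.create(arrowSchema, allocator)
     try {
       ArrowWriter.create(root).finish() // zero rows written
       val batch = new VectorUnloader(root).getRecordBatch
       val out = new ByteArrayOutputStream()
       val channel = new WriteChannel(Channels.newChannel(out))
       MessageSerializer.serialize(channel, arrowSchema)
       MessageSerializer.serialize(channel, batch)
       ArrowStreamWriter.writeEndOfStream(channel, IpcOption.DEFAULT)
       batch.close()
       out.toByteArray
     } finally {
       root.close()
       allocator.close()
     }
   }
   ```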




[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018741578


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()
+            }
+            val partition = partitions(currentPartitionId)
+            partitions(currentPartitionId) = null
+            partition
+          }
+
+          // only send non-empty partitions
+          if (partition.nonEmpty && partition.exists(_._1.nonEmpty)) {

Review Comment:
   Yes, at least for `collect` an empty partition is meaningless.
   
   It may be meaningful for some partitioning-aware operations like `RDD.zipPartitions`.





[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018741978


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##########
@@ -128,6 +128,97 @@ private[sql] object ArrowConverters extends Logging {
     }
   }
 
+  private[sql] def toArrowBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Int,
+      timeZoneId: String): Iterator[(Array[Byte], Long, Long)] = {
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+      "toArrowBatchIterator", 0, Long.MaxValue)
+
+    val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    val unloader = new VectorUnloader(root)
+    val arrowWriter = ArrowWriter.create(root)
+
+    Option(TaskContext.get).foreach {
+      _.addTaskCompletionListener[Unit] { _ =>
+        root.close()
+        allocator.close()
+      }
+    }
+
+    new Iterator[(Array[Byte], Long, Long)] {
+
+      override def hasNext: Boolean = rowIter.hasNext || {
+        root.close()
+        allocator.close()
+        false
+      }
+
+      override def next(): (Array[Byte], Long, Long) = {
+        val out = new ByteArrayOutputStream()
+        val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+        var rowCount = 0L
+        var estimatedSize = SizeEstimator.estimate(arrowSchema) +
+          SizeEstimator.estimate(IpcOption.DEFAULT)
+        Utils.tryWithSafeFinally {
+          while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+            val row = rowIter.next()
+            arrowWriter.write(row)
+            rowCount += 1
+            estimatedSize += SizeEstimator.estimate(row)
+          }
+          arrowWriter.finish()
+          val batch = unloader.getRecordBatch()
+
+          MessageSerializer.serialize(writeChannel, arrowSchema)
+          MessageSerializer.serialize(writeChannel, batch)
+          ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)
+
+          batch.close()
+        } {
+          arrowWriter.reset()
+        }
+
+        (out.toByteArray, rowCount, estimatedSize)
+      }
+    }
+  }
+
+  private[sql] def createEmptyArrowBatch(

Review Comment:
   let me take a look





[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018944392


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()
+            }
+            val partition = partitions(currentPartitionId)
+            partitions(currentPartitionId) = null
+            partition
+          }
+
+          // only send non-empty partitions
+          if (partition.nonEmpty && partition.exists(_._1.nonEmpty)) {

Review Comment:
   Different question. Why not just iterate over the partitions and keep only the non-empty batches? That should be the same, and it saves you from an unneeded if.



##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()
+            }
+            val partition = partitions(currentPartitionId)
+            partitions(currentPartitionId) = null
+            partition
+          }
+
+          // only send non-empty partitions
+          if (partition.nonEmpty && partition.exists(_._1.nonEmpty)) {

Review Comment:
   Different question. Why not just iterate over the partitions and keep only the non-empty batches? That should be the same, and it saves you from an unneeded if.
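   A sketch of that simplification, reusing the names from the diff above (not code from the PR):

   ```scala
   // Send every non-empty batch and silently skip the empty ones.
   partition.iterator
     .filter { case (bytes, _, _) => bytes.nonEmpty }
     .foreach { case (bytes, count, size) =>
       // build the proto.Response.ArrowBatch for this batch and call responseObserver.onNext(...)
       ()
     }
   ```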





[GitHub] [spark] pan3793 commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
pan3793 commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018969469


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()

Review Comment:
   If the first partition arrives last, the whole dataset stays in the driver's memory, right?





[GitHub] [spark] cloud-fan commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
cloud-fan commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019104867


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##########
@@ -128,6 +128,92 @@ private[sql] object ArrowConverters extends Logging {
     }
   }
 
+  private[sql] def toArrowBatchIterator(

Review Comment:
   Can we simply change `toBatchIterator` to return the row count as well? The perf overhead is very small, but the maintenance overhead is much larger if we have 2 similar methods.
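   If changing the return type turns out to be too invasive, one alternative (an assumption, not what the PR does) is to count rows in a thin wrapper around the input iterator, so a single batching method can serve both callers:

   ```scala
   // Counts how many elements the batching loop has pulled from the underlying iterator;
   // per-batch counts are the difference in `count` before and after each batch.
   class CountingIterator[T](underlying: Iterator[T]) extends Iterator[T] {
     private var consumed = 0L
     def count: Long = consumed
     override def hasNext: Boolean = underlying.hasNext
     override def next(): T = { consumed += 1; underlying.next() }
   }
   ```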





[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013189541


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##########
@@ -128,6 +128,65 @@ private[sql] object ArrowConverters extends Logging {
     }
   }
 
+  private[sql] def toArrowBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Int,
+      timeZoneId: String,
+      context: TaskContext): Iterator[(Array[Byte], Long, Long)] = {
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+      "toArrowBatchIterator", 0, Long.MaxValue)
+
+    val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    val unloader = new VectorUnloader(root)
+    val arrowWriter = ArrowWriter.create(root)
+
+    if (context != null) { // for test at driver
+      context.addTaskCompletionListener[Unit] { _ =>
+        root.close()
+        allocator.close()
+      }
+    }
+
+    new Iterator[(Array[Byte], Long, Long)] {
+
+      override def hasNext: Boolean = rowIter.hasNext || {
+        root.close()
+        allocator.close()
+        false
+      }
+
+      override def next(): (Array[Byte], Long, Long) = {
+        val out = new ByteArrayOutputStream()
+        val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+        var rowCount = 0L
+        var estimatedSize = 0L
+        Utils.tryWithSafeFinally {
+          while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+            val row = rowIter.next()
+            arrowWriter.write(row)
+            rowCount += 1
+            estimatedSize += SizeEstimator.estimate(row)
+          }
+          arrowWriter.finish()
+          val batch = unloader.getRecordBatch()
+
+          MessageSerializer.serialize(writeChannel, arrowSchema)

Review Comment:
   Instead of turning each written result into a separate IPC stream, why not make the driver send the schema, and then stream back the record batches? I am not 100% sure how hard it would be to reassemble this on the Python side; on the Scala side it would be fairly straightforward.
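   A minimal sketch of the server side under that scheme (helper names are assumed; the client would prepend the schema message itself when reassembling the stream):

   ```scala
   import java.io.ByteArrayOutputStream
   import java.nio.channels.Channels
   import org.apache.arrow.vector.ipc.WriteChannel
   import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer}
   import org.apache.arrow.vector.types.pojo.Schema

   // Sent once per result stream.
   def serializeSchemaMessage(arrowSchema: Schema): Array[Byte] = {
     val out = new ByteArrayOutputStream()
     MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), arrowSchema)
     out.toByteArray
   }

   // Sent for every batch, without repeating the schema or the end-of-stream marker.
   def serializeBatchMessage(batch: ArrowRecordBatch): Array[Byte] = {
     val out = new ByteArrayOutputStream()
     MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), batch)
     out.toByteArray
   }
   ```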





[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013186031


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,7 +131,36 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
+    responseObserver.onCompleted()
+  }
+
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val schema = dataframe.schema
+    val maxRecordsPerBatch = dataframe.sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = dataframe.sparkSession.sessionState.conf.sessionLocalTimeZone
+
+    val batches = dataframe.queryExecution.executedPlan
+      .execute()
+      .mapPartitionsInternal { iter =>
+        ArrowConverters
+          .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId, TaskContext.get)

Review Comment:
   Can you file a follow-up to limit the size of the batch? gRPC does not really like it when we send large messages (> 4 MB).





[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013540451


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,7 +131,36 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
+    responseObserver.onCompleted()
+  }
+
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val schema = dataframe.schema
+    val maxRecordsPerBatch = dataframe.sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = dataframe.sparkSession.sessionState.conf.sessionLocalTimeZone
+
+    val batches = dataframe.queryExecution.executedPlan
+      .execute()
+      .mapPartitionsInternal { iter =>
+        ArrowConverters
+          .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId, TaskContext.get)

Review Comment:
   ok, let me add a todo item





[GitHub] [spark] grundprinzip commented on pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
grundprinzip commented on PR #38468:
URL: https://github.com/apache/spark/pull/38468#issuecomment-1303010466

   @zhengruifeng can you please add test cases for things like
   
   `select * from table limit 0`, where the optimizer decides there are no qualifying rows but we still have to return the schema with an empty result. Right now, when the query returns an empty result you do not return anything, which is not a valid result.




[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1014559938


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,7 +126,70 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
+    responseObserver.onCompleted()
+  }
+
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    val rows = dataframe.queryExecution.executedPlan.execute()
+    var numBatches = 0L
+
+    if (rows.getNumPartitions > 0) {
+      val batches = rows.mapPartitionsInternal { iter =>
+        ArrowConverters
+          .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+      }
+
+      val obj = new Object
+
+      val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray
+
+      val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], Long, Long)]) =>
+        obj.synchronized {
+          var batchId = partitionId.toLong << 33
+          taskResult.foreach { case (bytes, count, size) =>
+            val response = proto.Response.newBuilder().setClientId(clientId)
+            val batch = proto.Response.ArrowBatch
+              .newBuilder()
+              .setBatchId(batchId)
+              .setRowCount(count)
+              .setUncompressedBytes(size)
+              .setCompressedBytes(bytes.length)
+              .setData(ByteString.copyFrom(bytes))
+              .build()
+            response.setArrowBatch(batch)
+            responseObserver.onNext(response.build())

Review Comment:
   ok will update





[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1014560227


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,7 +126,70 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
+    responseObserver.onCompleted()
+  }
+
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    val rows = dataframe.queryExecution.executedPlan.execute()
+    var numBatches = 0L
+
+    if (rows.getNumPartitions > 0) {
+      val batches = rows.mapPartitionsInternal { iter =>
+        ArrowConverters
+          .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+      }
+
+      val obj = new Object
+
+      val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray

Review Comment:
   With `batch_id`, we can send batches from higher partitions before lower ones, since the client can restore the order.
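   
   For reference, the `partitionId.toLong << 33` in the snippet above packs the partition id into the high bits of the batch id; a small sketch of that encoding together with a hypothetical decoder the client could use:
   
   ```
   object BatchIdCodec {
     // High bits carry the partition id, the low 33 bits a per-partition sequence
     // number; mirrors the `partitionId.toLong << 33` used in the handler above.
     def encode(partitionId: Int, indexInPartition: Long): Long =
       (partitionId.toLong << 33) | indexInPartition
   
     // Hypothetical inverse for a client that wants to restore partition order.
     def decode(batchId: Long): (Int, Long) =
       ((batchId >>> 33).toInt, batchId & ((1L << 33) - 1))
   }
   ```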





[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1015644446


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val pool = ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")

Review Comment:
   You can also use 2 fields in the proto: `partition_id` & `batch_id`





[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013195074


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -49,21 +51,33 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
     }
   }
 
-  def handlePlan(session: SparkSession, request: proto.Request): Unit = {
+  def handlePlan(session: SparkSession, request: Request): Unit = {
     // Extract the plan from the request and convert it to a logical plan
     val planner = new SparkConnectPlanner(request.getPlan.getRoot, session)
-    val rows =
-      Dataset.ofRows(session, planner.transform())
-    processRows(request.getClientId, rows)
+    val dataframe = Dataset.ofRows(session, planner.transform())
+    request.getPreferredResultType match {
+      case Request.ResultType.ArrowBatch =>
+        // check whether all data types are supported

Review Comment:
   What is not supported?





[GitHub] [spark] grundprinzip commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
grundprinzip commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013642456


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##########
@@ -128,6 +128,65 @@ private[sql] object ArrowConverters extends Logging {
     }
   }
 
+  private[sql] def toArrowBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Int,
+      timeZoneId: String,
+      context: TaskContext): Iterator[(Array[Byte], Long, Long)] = {
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+      "toArrowBatchIterator", 0, Long.MaxValue)
+
+    val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    val unloader = new VectorUnloader(root)
+    val arrowWriter = ArrowWriter.create(root)
+
+    if (context != null) { // for test at driver
+      context.addTaskCompletionListener[Unit] { _ =>
+        root.close()
+        allocator.close()
+      }
+    }
+
+    new Iterator[(Array[Byte], Long, Long)] {
+
+      override def hasNext: Boolean = rowIter.hasNext || {
+        root.close()
+        allocator.close()
+        false
+      }
+
+      override def next(): (Array[Byte], Long, Long) = {
+        val out = new ByteArrayOutputStream()
+        val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+        var rowCount = 0L
+        var estimatedSize = 0L
+        Utils.tryWithSafeFinally {
+          while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+            val row = rowIter.next()
+            arrowWriter.write(row)
+            rowCount += 1
+            estimatedSize += SizeEstimator.estimate(row)
+          }
+          arrowWriter.finish()
+          val batch = unloader.getRecordBatch()
+
+          MessageSerializer.serialize(writeChannel, arrowSchema)

Review Comment:
   Right now the client simply does a pandas concat of all the batches coming from the server. The benefit is that we don't have to wait for the whole result before we can start serving data to the user.
   
   Logically it makes sense to send the schema only once, but right now my suggestion would be to keep it like this to make the default consumption easier. Every Arrow batch we send is a fully contained Arrow IPC stream.
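   
   For illustration only (not part of the PR), this is what "fully contained" means from a consumer's point of view on the JVM: each payload can be opened and read on its own with a plain Arrow stream reader.
   
   ```
   import java.io.ByteArrayInputStream
   
   import org.apache.arrow.memory.RootAllocator
   import org.apache.arrow.vector.ipc.ArrowStreamReader
   
   object ReadSelfContainedBatch {
     // Reads one `data` payload as a complete IPC stream (schema + batch + EOS)
     // and returns the number of rows it carries.
     def rowCount(payload: Array[Byte]): Long = {
       val allocator = new RootAllocator(Long.MaxValue)
       val reader = new ArrowStreamReader(new ByteArrayInputStream(payload), allocator)
       try {
         var rows = 0L
         while (reader.loadNextBatch()) {
           rows += reader.getVectorSchemaRoot.getRowCount
         }
         rows
       } finally {
         reader.close()
         allocator.close()
       }
     }
   }
   ```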





[GitHub] [spark] grundprinzip commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
grundprinzip commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013643400


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##########
@@ -128,6 +128,65 @@ private[sql] object ArrowConverters extends Logging {
     }
   }
 
+  private[sql] def toArrowBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Int,
+      timeZoneId: String,
+      context: TaskContext): Iterator[(Array[Byte], Long, Long)] = {
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+      "toArrowBatchIterator", 0, Long.MaxValue)
+
+    val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    val unloader = new VectorUnloader(root)
+    val arrowWriter = ArrowWriter.create(root)
+
+    if (context != null) { // for test at driver
+      context.addTaskCompletionListener[Unit] { _ =>
+        root.close()
+        allocator.close()
+      }
+    }
+
+    new Iterator[(Array[Byte], Long, Long)] {
+
+      override def hasNext: Boolean = rowIter.hasNext || {
+        root.close()
+        allocator.close()
+        false
+      }
+
+      override def next(): (Array[Byte], Long, Long) = {

Review Comment:
   Did you check what happens if you do a `select * from table limit 0`? I have previously had some challenges where the query returned 0 partitions and, because of that, we would not send a schema.
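   
   A hypothetical repro of that scenario (names are illustrative; a `LIMIT 0` query is typically optimized to an empty local relation whose RDD has zero partitions, so no task ever emits an Arrow payload):
   
   ```
   import org.apache.spark.sql.SparkSession
   
   object Limit0Repro {
     def main(args: Array[String]): Unit = {
       val spark = SparkSession.builder().master("local[1]").getOrCreate()
       // With zero partitions, mapPartitionsInternal never runs, so the driver
       // has to send a schema-only batch explicitly.
       val rdd = spark.sql("SELECT * FROM range(10) LIMIT 0")
         .queryExecution.executedPlan.execute()
       println(s"numPartitions = ${rdd.getNumPartitions}")
       spark.stop()
     }
   }
   ```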





[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013549587


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##########
@@ -128,6 +128,65 @@ private[sql] object ArrowConverters extends Logging {
     }
   }
 
+  private[sql] def toArrowBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Int,
+      timeZoneId: String,
+      context: TaskContext): Iterator[(Array[Byte], Long, Long)] = {
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+      "toArrowBatchIterator", 0, Long.MaxValue)
+
+    val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    val unloader = new VectorUnloader(root)
+    val arrowWriter = ArrowWriter.create(root)
+
+    if (context != null) { // for test at driver
+      context.addTaskCompletionListener[Unit] { _ =>
+        root.close()
+        allocator.close()
+      }
+    }
+
+    new Iterator[(Array[Byte], Long, Long)] {
+
+      override def hasNext: Boolean = rowIter.hasNext || {
+        root.close()
+        allocator.close()
+        false
+      }
+
+      override def next(): (Array[Byte], Long, Long) = {
+        val out = new ByteArrayOutputStream()
+        val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+        var rowCount = 0L
+        var estimatedSize = 0L
+        Utils.tryWithSafeFinally {
+          while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+            val row = rowIter.next()
+            arrowWriter.write(row)
+            rowCount += 1
+            estimatedSize += SizeEstimator.estimate(row)
+          }
+          arrowWriter.finish()
+          val batch = unloader.getRecordBatch()
+
+          MessageSerializer.serialize(writeChannel, arrowSchema)

Review Comment:
   I had tried to split the schema and the record batches, but didn't find an easy way to do this. Then I made this change as per @grundprinzip's and @HyukjinKwon's suggestion.





[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013550540


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,7 +131,36 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
+    responseObserver.onCompleted()
+  }
+
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val schema = dataframe.schema
+    val maxRecordsPerBatch = dataframe.sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = dataframe.sparkSession.sessionState.conf.sessionLocalTimeZone
+
+    val batches = dataframe.queryExecution.executedPlan
+      .execute()
+      .mapPartitionsInternal { iter =>
+        ArrowConverters
+          .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId, TaskContext.get)
+      }
+
+    batches.toLocalIterator.foreach { case (bytes, count, size) =>

Review Comment:
   Maybe we should also add a `partitionId` field in the proto message; then we sort by it in the client to keep the order.





[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1014279677


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,7 +126,70 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
+    responseObserver.onCompleted()
+  }
+
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    val rows = dataframe.queryExecution.executedPlan.execute()
+    var numBatches = 0L
+
+    if (rows.getNumPartitions > 0) {
+      val batches = rows.mapPartitionsInternal { iter =>
+        ArrowConverters
+          .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+      }
+
+      val obj = new Object
+
+      val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray
+
+      val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], Long, Long)]) =>
+        obj.synchronized {
+          var batchId = partitionId.toLong << 33
+          taskResult.foreach { case (bytes, count, size) =>
+            val response = proto.Response.newBuilder().setClientId(clientId)
+            val batch = proto.Response.ArrowBatch
+              .newBuilder()
+              .setBatchId(batchId)
+              .setRowCount(count)
+              .setUncompressedBytes(size)
+              .setCompressedBytes(bytes.length)
+              .setData(ByteString.copyFrom(bytes))
+              .build()
+            response.setArrowBatch(batch)
+            responseObserver.onNext(response.build())

Review Comment:
   This callback is currently executed by the DAGScheduler thread, and if it is expensive, no job/stage can be scheduled during this call. We really should move this to a separate thread.
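   
   A minimal sketch of such a hand-off, assuming a plain blocking queue drained by a daemon sender thread (illustrative; the executor is never shut down here, and the PR later uses a dedicated single-thread executor for the same purpose):
   
   ```
   import java.util.concurrent.{Executors, LinkedBlockingQueue, ThreadFactory}
   
   object AsyncResultHandlerSketch {
     type Batch = (Array[Byte], Long, Long)
   
     // The handler passed to runJob only enqueues; a daemon thread drains the
     // queue and does the expensive work (building responses, calling onNext).
     def resultHandler(send: Batch => Unit): (Int, Array[Batch]) => Unit = {
       val queue = new LinkedBlockingQueue[Array[Batch]]()
       val factory: ThreadFactory = (r: Runnable) => {
         val t = new Thread(r, "connect-collect-arrow")
         t.setDaemon(true)
         t
       }
       Executors.newSingleThreadExecutor(factory).execute(() => {
         while (!Thread.currentThread().isInterrupted) {
           queue.take().foreach(send)
         }
       })
       // Runs on the DAGScheduler thread: it must return quickly.
       (partitionId: Int, taskResult: Array[Batch]) => queue.put(taskResult)
     }
   }
   ```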





[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1014270879


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,7 +126,70 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
+    responseObserver.onCompleted()
+  }
+
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    val rows = dataframe.queryExecution.executedPlan.execute()

Review Comment:
   You need to wrap this and all the code below in `SQLExecution.withExecutionId(..)`, or use `Dataset.withAction`; otherwise you break the UI and a bunch of other things.
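   
   A small sketch of the wrapping being asked for, in the form the PR later adopts further down in this thread (the `"collectArrow"` label comes from that updated diff):
   
   ```
   import org.apache.spark.sql.DataFrame
   import org.apache.spark.sql.execution.SQLExecution
   
   object TrackedCollectSketch {
     // Everything that triggers Spark jobs runs under a tracked execution id, so
     // the SQL tab and query listeners see the collect like any other action.
     def collectArrow[T](dataframe: DataFrame)(body: => T): T = {
       SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
         body // e.g. executedPlan.execute(), runJob(...), streaming the batches
       }
     }
   }
   ```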





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018680299


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit

Review Comment:
   Is this used somewhere?





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018679792


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records

Review Comment:
   I think you can remove this since we're already handling the max records?





[GitHub] [spark] pan3793 commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
pan3793 commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019007963


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()

Review Comment:
   Another question: if there is no `Sort`, do we need to keep the partition order?





[GitHub] [spark] grundprinzip commented on pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
grundprinzip commented on PR #38468:
URL: https://github.com/apache/spark/pull/38468#issuecomment-1310307256

   I would like to close the discussion on ordered vs. unordered results.
   
   1) For simple clients, ordered results are what they expect, and this follows the precedent of what users expect from traditional database drivers.
   
   2) This is not a one-way door: when we believe we can no longer live with ordered results and they are causing issues, we will revisit this topic.
   
   Having this discussion without appropriate evidence to support either mode does not help the efficient evolution of this feature.
   
   Thank you.




[GitHub] [spark] cloud-fan commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
cloud-fan commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019151257


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()

Review Comment:
   Agreed, it's also important to keep client implementations simple. This "async collect" should be OK in most cases.





[GitHub] [spark] cloud-fan commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
cloud-fan commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019154009


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##########
@@ -128,6 +128,97 @@ private[sql] object ArrowConverters extends Logging {
     }
   }
 
+  private[sql] def toArrowBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Int,
+      timeZoneId: String): Iterator[(Array[Byte], Long, Long)] = {
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+      "toArrowBatchIterator", 0, Long.MaxValue)
+
+    val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    val unloader = new VectorUnloader(root)
+    val arrowWriter = ArrowWriter.create(root)
+
+    Option(TaskContext.get).foreach {
+      _.addTaskCompletionListener[Unit] { _ =>
+        root.close()
+        allocator.close()
+      }
+    }
+
+    new Iterator[(Array[Byte], Long, Long)] {
+
+      override def hasNext: Boolean = rowIter.hasNext || {
+        root.close()
+        allocator.close()
+        false
+      }
+
+      override def next(): (Array[Byte], Long, Long) = {
+        val out = new ByteArrayOutputStream()
+        val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+        var rowCount = 0L
+        var estimatedSize = SizeEstimator.estimate(arrowSchema) +
+          SizeEstimator.estimate(IpcOption.DEFAULT)
+        Utils.tryWithSafeFinally {
+          while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+            val row = rowIter.next()
+            arrowWriter.write(row)
+            rowCount += 1
+            estimatedSize += SizeEstimator.estimate(row)
+          }
+          arrowWriter.finish()
+          val batch = unloader.getRecordBatch()
+
+          MessageSerializer.serialize(writeChannel, arrowSchema)
+          MessageSerializer.serialize(writeChannel, batch)
+          ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)
+
+          batch.close()
+        } {
+          arrowWriter.reset()
+        }
+
+        (out.toByteArray, rowCount, estimatedSize)
+      }
+    }
+  }
+
+  private[sql] def createEmptyArrowBatch(

Review Comment:
   how is this different from calling `toArrowBatchIterator` with an empty iterator?
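   
   For context: with an empty row iterator, the `hasNext` shown above returns `false` immediately, so nothing gets written at all, not even the schema; a schema-only payload has to be produced explicitly. A hypothetical sketch of that helper (the actual `createEmptyArrowBatch` body is not shown in this diff, and the sketch assumes access to Spark's package-private `ArrowUtils`/`ArrowWriter`):
   
   ```
   import java.io.ByteArrayOutputStream
   import java.nio.channels.Channels
   
   import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader}
   import org.apache.arrow.vector.ipc.{ArrowStreamWriter, WriteChannel}
   import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer}
   import org.apache.spark.sql.execution.arrow.ArrowWriter
   import org.apache.spark.sql.types.StructType
   import org.apache.spark.sql.util.ArrowUtils
   
   object EmptyBatchSketch {
     // Writes schema + a zero-row record batch + end-of-stream, so the client
     // still learns the schema even when the query produces no rows.
     def emptyArrowBatch(schema: StructType, timeZoneId: String): Array[Byte] = {
       val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
       val allocator =
         ArrowUtils.rootAllocator.newChildAllocator("emptyArrowBatch", 0, Long.MaxValue)
       val root = VectorSchemaRoot.create(arrowSchema, allocator)
       try {
         ArrowWriter.create(root).finish()
         val out = new ByteArrayOutputStream()
         val channel = new WriteChannel(Channels.newChannel(out))
         val batch = new VectorUnloader(root).getRecordBatch
         MessageSerializer.serialize(channel, arrowSchema)
         MessageSerializer.serialize(channel, batch)
         ArrowStreamWriter.writeEndOfStream(channel, IpcOption.DEFAULT)
         batch.close()
         out.toByteArray
       } finally {
         root.close()
         allocator.close()
       }
     }
   }
   ```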





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018681018


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()
+            }
+            val partition = partitions(currentPartitionId)
+            partitions(currentPartitionId) = null
+            partition
+          }
+
+          // only send non-empty partitions
+          if (partition.nonEmpty && partition.exists(_._1.nonEmpty)) {
+            partition.foreach { case (bytes, count, size) =>
+              val response = proto.Response.newBuilder().setClientId(clientId)
+              val batch = proto.Response.ArrowBatch
+                .newBuilder()
+                .setRowCount(count)
+                .setUncompressedBytes(size)
+                .setCompressedBytes(bytes.length)
+                .setData(ByteString.copyFrom(bytes))
+                .build()
+              response.setArrowBatch(batch)
+              responseObserver.onNext(response.build())
+            }
+            numSent += 1

Review Comment:
   can we use a boolean?





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018678853


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -48,19 +51,25 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
     }
   }
 
-  def handlePlan(session: SparkSession, request: proto.Request): Unit = {
+  def handlePlan(session: SparkSession, request: Request): Unit = {
     // Extract the plan from the request and convert it to a logical plan
     val planner = new SparkConnectPlanner(request.getPlan.getRoot, session)
-    val rows =
-      Dataset.ofRows(session, planner.transform())
-    processRows(request.getClientId, rows)
+    val dataframe = Dataset.ofRows(session, planner.transform())
+    // check whether all data types are supported
+    if (Try {
+        ArrowUtils.toArrowSchema(dataframe.schema, session.sessionState.conf.sessionLocalTimeZone)
+      }.isSuccess) {
+      processRowsAsArrowBatches(request.getClientId, dataframe)
+    } else {
+      processRowsAsJsonBatches(request.getClientId, dataframe)

Review Comment:
   ```
   try {
     ...
   } catch {
   ...
   }
   ```
   
   Or
   
   ```
   Try { ... }
     .map(_ => processRowsAsArrowBatches(request.getClientId, dataframe))
     .getOrElse(processRowsAsJsonBatches(request.getClientId, dataframe))
   ```





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018682109


##########
python/pyspark/sql/connect/client.py:
##########
@@ -400,6 +400,14 @@ def _execute_and_fetch(self, req: pb2.Request) -> typing.Optional[pandas.DataFra
 
         if len(result_dfs) > 0:
             df = pd.concat(result_dfs)
+            del result_dfs

Review Comment:
   why?





[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018951362


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records

Review Comment:
   Another downside of large allocations is that the GC does not really like them. All large allocations (> 1 MB) are generally placed in the old generation immediately, which requires a full GC to clean up.





[GitHub] [spark] pan3793 commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
pan3793 commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018997361


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()

Review Comment:
   What do you think of this approach? https://github.com/apache/spark/pull/38468#discussion_r1016118342





[GitHub] [spark] pan3793 commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
pan3793 commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019058105


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()

Review Comment:
   In the reduce phase, a task fetches the map data of each partition in random order; without a local sort, the user still sees indeterminate data even if the driver returns the data by partition number.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019118121


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##########
@@ -128,6 +128,92 @@ private[sql] object ArrowConverters extends Logging {
     }
   }
 
+  private[sql] def toArrowBatchIterator(

Review Comment:
   `toArrowBatchIterator` also writes the schema before each record batch, while `toBatchIterator` outputs only the record batch.
   
   I am also going to update `toArrowBatchIterator` in a follow-up to keep each batch size < 4MB, as per the suggestions https://github.com/apache/spark/pull/38468#discussion_r1018951362   https://github.com/apache/spark/pull/38468#discussion_r1013186031
   
   I think we can deduplicate the code then.
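   For reference, a minimal sketch of that difference, using the Arrow classes already referenced in this diff (the helper names here are made up for illustration): the connect path emits a self-describing chunk, while the `toBatchIterator` style emits only the bare record batch bytes.
   
   ```scala
   import java.io.ByteArrayOutputStream
   import java.nio.channels.Channels
   
   import org.apache.arrow.vector.ipc.{ArrowStreamWriter, WriteChannel}
   import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, IpcOption, MessageSerializer}
   import org.apache.arrow.vector.types.pojo.Schema
   
   // Self-describing chunk: schema + record batch + end-of-stream marker,
   // so a client can decode each chunk on its own.
   def writeSelfDescribingChunk(arrowSchema: Schema, batch: ArrowRecordBatch): Array[Byte] = {
     val out = new ByteArrayOutputStream()
     val channel = new WriteChannel(Channels.newChannel(out))
     MessageSerializer.serialize(channel, arrowSchema)
     MessageSerializer.serialize(channel, batch)
     ArrowStreamWriter.writeEndOfStream(channel, IpcOption.DEFAULT)
     out.toByteArray
   }
   
   // Bare chunk: only the record batch; the reader must already know the schema.
   def writeBareChunk(batch: ArrowRecordBatch): Array[Byte] = {
     val out = new ByteArrayOutputStream()
     val channel = new WriteChannel(Channels.newChannel(out))
     MessageSerializer.serialize(channel, batch)
     out.toByteArray
   }
   ```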



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019170138


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##########
@@ -128,6 +128,97 @@ private[sql] object ArrowConverters extends Logging {
     }
   }
 
+  private[sql] def toArrowBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Int,
+      timeZoneId: String): Iterator[(Array[Byte], Long, Long)] = {
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+      "toArrowBatchIterator", 0, Long.MaxValue)
+
+    val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    val unloader = new VectorUnloader(root)
+    val arrowWriter = ArrowWriter.create(root)
+
+    Option(TaskContext.get).foreach {
+      _.addTaskCompletionListener[Unit] { _ =>
+        root.close()
+        allocator.close()
+      }
+    }
+
+    new Iterator[(Array[Byte], Long, Long)] {
+
+      override def hasNext: Boolean = rowIter.hasNext || {
+        root.close()
+        allocator.close()
+        false
+      }
+
+      override def next(): (Array[Byte], Long, Long) = {
+        val out = new ByteArrayOutputStream()
+        val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+        var rowCount = 0L
+        var estimatedSize = SizeEstimator.estimate(arrowSchema) +
+          SizeEstimator.estimate(IpcOption.DEFAULT)
+        Utils.tryWithSafeFinally {
+          while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+            val row = rowIter.next()
+            arrowWriter.write(row)
+            rowCount += 1
+            estimatedSize += SizeEstimator.estimate(row)
+          }
+          arrowWriter.finish()
+          val batch = unloader.getRecordBatch()
+
+          MessageSerializer.serialize(writeChannel, arrowSchema)
+          MessageSerializer.serialize(writeChannel, batch)
+          ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)
+
+          batch.close()
+        } {
+          arrowWriter.reset()
+        }
+
+        (out.toByteArray, rowCount, estimatedSize)
+      }
+    }
+  }
+
+  private[sql] def createEmptyArrowBatch(

Review Comment:
   Calling `toArrowBatchIterator` with an empty iterator returns an empty iterator.
   
   Here we need an Arrow batch that carries the schema but no data.
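   A minimal sketch of what such a helper could do, reusing the Arrow classes from `toArrowBatchIterator` above (the method name is illustrative, and `ArrowUtils`/`ArrowWriter` are the Spark-internal helpers linked later in this thread): finish the writer without writing any rows, so the payload carries the schema and a zero-row record batch.
   
   ```scala
   import java.io.ByteArrayOutputStream
   import java.nio.channels.Channels
   
   import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader}
   import org.apache.arrow.vector.ipc.{ArrowStreamWriter, WriteChannel}
   import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer}
   
   import org.apache.spark.sql.execution.arrow.ArrowWriter
   import org.apache.spark.sql.types.StructType
   import org.apache.spark.sql.util.ArrowUtils
   
   def createEmptyArrowBatchSketch(schema: StructType, timeZoneId: String): Array[Byte] = {
     val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
     val allocator =
       ArrowUtils.rootAllocator.newChildAllocator("createEmptyArrowBatch", 0, Long.MaxValue)
     val root = VectorSchemaRoot.create(arrowSchema, allocator)
     try {
       val unloader = new VectorUnloader(root)
       val arrowWriter = ArrowWriter.create(root)
       arrowWriter.finish() // no rows were written, so this yields a zero-row batch
       val batch = unloader.getRecordBatch()
   
       val out = new ByteArrayOutputStream()
       val channel = new WriteChannel(Channels.newChannel(out))
       MessageSerializer.serialize(channel, arrowSchema)
       MessageSerializer.serialize(channel, batch)
       ArrowStreamWriter.writeEndOfStream(channel, IpcOption.DEFAULT)
       batch.close()
       out.toByteArray
     } finally {
       root.close()
       allocator.close()
     }
   }
   ```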



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] pan3793 commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
pan3793 commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019048009


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val pool = ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")
+      val tasks = collection.mutable.ArrayBuffer.empty[Future[_]]
+      val rows = dataframe.queryExecution.executedPlan.execute()
+
+      if (rows.getNumPartitions > 0) {
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], Long, Long)]) => {
+          if (taskResult.exists(_._1.nonEmpty)) {
+            // only send non-empty partitions
+            val task = pool.submit(new Runnable {
+              override def run(): Unit = {
+                var batchId = partitionId.toLong << 33
+                taskResult.foreach { case (bytes, count, size) =>
+                  val response = proto.Response.newBuilder().setClientId(clientId)
+                  val batch = proto.Response.ArrowBatch
+                    .newBuilder()
+                    .setBatchId(batchId)
+                    .setRowCount(count)
+                    .setUncompressedBytes(size)
+                    .setCompressedBytes(bytes.length)
+                    .setData(ByteString.copyFrom(bytes))
+                    .build()
+                  response.setArrowBatch(batch)
+                  responseObserver.onNext(response.build())

Review Comment:
   To be clear, my suggestion is to defer the deserialization and block releasing to the `JobWaiter#taskSucceeded` phase.
   
   The second way you suggested has many performance benefits but requires the client to communicate with the external storage service.
   
   The `IndirectTaskResult` way leverages the existing code and Spark's built-in block mechanism to transfer data; we can benefit with only a little code modification, and we don't need to worry about result cleanup.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1017732568


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -86,12 +93,15 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
           val response = proto.Response.newBuilder().setClientId(clientId)
           val batch = proto.Response.JSONBatch
             .newBuilder()
+            .setPartitionId(-1)

Review Comment:
   What is this? If you don't use it, don't set it.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018682801


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit

Review Comment:
   Yes, it is used to change the return type to `Unit`.
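   A tiny standalone illustration of the Scala behavior being referenced (not the PR's code): the result type of a function literal is the type of its last expression, so ending the block with a throwaway statement, or simply `()`, pins it to `Unit`.
   
   ```scala
   // Last expression determines the inferred result type of the function literal.
   val returnsInt: Int => Int = x => { x + 1 }
   // Ending with () (or a throwaway statement, as the diff does) yields Unit.
   val returnsUnit: Int => Unit = x => { val y = x + 1; () }
   ```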



##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()
+            }
+            val partition = partitions(currentPartitionId)
+            partitions(currentPartitionId) = null
+            partition
+          }
+
+          // only send non-empty partitions
+          if (partition.nonEmpty && partition.exists(_._1.nonEmpty)) {
+            partition.foreach { case (bytes, count, size) =>
+              val response = proto.Response.newBuilder().setClientId(clientId)
+              val batch = proto.Response.ArrowBatch
+                .newBuilder()
+                .setRowCount(count)
+                .setUncompressedBytes(size)
+                .setCompressedBytes(bytes.length)
+                .setData(ByteString.copyFrom(bytes))
+                .build()
+              response.setArrowBatch(batch)
+              responseObserver.onNext(response.build())
+            }
+            numSent += 1

Review Comment:
   sure



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018683751


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()

Review Comment:
   No, partitions can be fetched in random order; here we wait for the `currentPartitionId`-th partition (starting from 0).
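   A standalone sketch of that handshake (illustrative names only, not the PR's code): results may land in any order, but the sender blocks until the partition it currently needs has arrived, then emits it and frees the slot.
   
   ```scala
   import scala.util.Random
   
   object InOrderEmitSketch {
     def main(args: Array[String]): Unit = {
       val numPartitions = 4
       val lock = new Object
       val slots = Array.fill[Option[String]](numPartitions)(None)
   
       // Producer: fills slots in random order, like tasks finishing out of order.
       new Thread(() => {
         Random.shuffle((0 until numPartitions).toList).foreach { id =>
           Thread.sleep(50)
           lock.synchronized {
             slots(id) = Some(s"partition-$id")
             lock.notifyAll()
           }
         }
       }).start()
   
       // Consumer: emits strictly in partition order, waiting for missing slots.
       var current = 0
       while (current < numPartitions) {
         val value = lock.synchronized {
           while (slots(current).isEmpty) { lock.wait() }
           val v = slots(current).get
           slots(current) = None // release the slot as soon as it is sent
           v
         }
         println(s"sending $value")
         current += 1
       }
     }
   }
   ```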



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] HyukjinKwon commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018688539


##########
python/pyspark/sql/connect/client.py:
##########
@@ -400,6 +400,14 @@ def _execute_and_fetch(self, req: pb2.Request) -> typing.Optional[pandas.DataFra
 
         if len(result_dfs) > 0:
             df = pd.concat(result_dfs)
+            del result_dfs

Review Comment:
   wouldn't be needed.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] cloud-fan commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
cloud-fan commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019100196


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()

Review Comment:
   The current approach still has the risk of holding all the results in driver memory (assuming the first partition comes last), which violates the design goal of Spark Connect.
   
   I think the Spark driver should send whichever partition arrives to the client, and the client should allocate an array to hold the Arrow batches of all partitions. The client needs to keep all the results in memory anyway, so it's better to ask the client to buffer the results and reorder them by partition id.
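   For what it's worth, the client-side half of that alternative could be as simple as the following sketch (the message shape is assumed, not the current proto): buffer batches as they stream in, then restore partition order before concatenating.
   
   ```scala
   // Batches tagged with their origin; partitionId/batchId are assumed fields.
   final case class TaggedBatch(partitionId: Int, batchId: Long, data: Array[Byte])
   
   // The client buffers everything anyway, so it can reorder deterministically.
   def restorePartitionOrder(received: Seq[TaggedBatch]): Seq[Array[Byte]] =
     received.sortBy(b => (b.partitionId, b.batchId)).map(_.data)
   ```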
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1016358065


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val pool = ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")

Review Comment:
   > Why not use the main thread for this?
   
   ok will follow https://github.com/apache/spark/pull/38468#discussion_r1013184548
   
   > You can also use 2 fields in the proto: partition_id & batch_id
   
   Done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] HyukjinKwon commented on pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on PR #38468:
URL: https://github.com/apache/spark/pull/38468#issuecomment-1309825071

   Logic-wise, makes sense.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018682607


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records

Review Comment:
   https://github.com/apache/spark/pull/38468#discussion_r1013186031 suggested controlling the batch size to be < 4MB.
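   For context, a sketch of what size-based batching could look like (the 4MB cap and the helper are assumptions, not this PR's code): keep appending rows to the current batch until the estimated size reaches the cap, then cut the batch.
   
   ```scala
   // Cap assumed for illustration; the real limit is still under discussion.
   val maxBatchSizeBytes: Long = 4L * 1024 * 1024
   
   // Groups rows into batches whose estimated size stays near the cap. A row that
   // pushes the estimate over the cap is still kept in the current batch, so a
   // batch may slightly exceed the cap but never splits a row.
   def sliceBySize[T](rows: Iterator[T], estimateBytes: T => Long): Iterator[Seq[T]] =
     new Iterator[Seq[T]] {
       override def hasNext: Boolean = rows.hasNext
       override def next(): Seq[T] = {
         val buf = Seq.newBuilder[T]
         var size = 0L
         while (rows.hasNext && size < maxBatchSizeBytes) {
           val row = rows.next()
           buf += row
           size += estimateBytes(row)
         }
         buf.result()
       }
     }
   ```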
   
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018875460


##########
connector/connect/src/main/protobuf/spark/connect/base.proto:
##########
@@ -83,7 +83,6 @@ message Response {
     int64 uncompressed_bytes = 2;

Review Comment:
   yeah, will get rid of it



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018928333


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##########
@@ -128,6 +128,97 @@ private[sql] object ArrowConverters extends Logging {
     }
   }
 
+  private[sql] def toArrowBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Int,
+      timeZoneId: String): Iterator[(Array[Byte], Long, Long)] = {
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+      "toArrowBatchIterator", 0, Long.MaxValue)
+
+    val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    val unloader = new VectorUnloader(root)
+    val arrowWriter = ArrowWriter.create(root)
+
+    Option(TaskContext.get).foreach {
+      _.addTaskCompletionListener[Unit] { _ =>
+        root.close()
+        allocator.close()
+      }
+    }
+
+    new Iterator[(Array[Byte], Long, Long)] {
+
+      override def hasNext: Boolean = rowIter.hasNext || {
+        root.close()
+        allocator.close()
+        false
+      }
+
+      override def next(): (Array[Byte], Long, Long) = {
+        val out = new ByteArrayOutputStream()
+        val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+        var rowCount = 0L
+        var estimatedSize = SizeEstimator.estimate(arrowSchema) +
+          SizeEstimator.estimate(IpcOption.DEFAULT)
+        Utils.tryWithSafeFinally {
+          while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+            val row = rowIter.next()
+            arrowWriter.write(row)
+            rowCount += 1
+            estimatedSize += SizeEstimator.estimate(row)
+          }
+          arrowWriter.finish()
+          val batch = unloader.getRecordBatch()
+
+          MessageSerializer.serialize(writeChannel, arrowSchema)
+          MessageSerializer.serialize(writeChannel, batch)
+          ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)
+
+          batch.close()
+        } {
+          arrowWriter.reset()
+        }
+
+        (out.toByteArray, rowCount, estimatedSize)
+      }
+    }
+  }
+
+  private[sql] def createEmptyArrowBatch(

Review Comment:
   I still haven't figured out how to deduplicate; `toArrowBatchIterator` also needs to return `rowCount`.
   
   What about trying to do this after switching to `batch size < 4MB`?
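   One possible shape for the deduplication, sketched under the assumption that the row loop is the shared part (none of these names exist in the codebase): the helper owns the loop and hands each finished record batch plus its row count to a callback, so `toArrowBatchIterator` can keep the count while `toBatchIterator` ignores it.
   
   ```scala
   import org.apache.arrow.vector.ipc.message.ArrowRecordBatch
   
   // Shared row loop: writeRow/finishBatch would be backed by the ArrowWriter and
   // VectorUnloader that each caller already sets up.
   def foreachRecordBatch[Row](
       rows: Iterator[Row],
       maxRecordsPerBatch: Int,
       writeRow: Row => Unit,
       finishBatch: () => ArrowRecordBatch,
       handle: (ArrowRecordBatch, Long) => Unit): Unit = {
     while (rows.hasNext) {
       var rowCount = 0L
       while (rows.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
         writeRow(rows.next())
         rowCount += 1
       }
       val batch = finishBatch()
       try handle(batch, rowCount) finally batch.close()
     }
   }
   ```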



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018995213


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()

Review Comment:
   Yes. We can look into spilling to deal with these situations.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] pan3793 commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
pan3793 commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019007963


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()

Review Comment:
   And another question: if there is no `Sort`, do we need to preserve the partition order?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] pan3793 commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
pan3793 commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019048009


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val pool = ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")
+      val tasks = collection.mutable.ArrayBuffer.empty[Future[_]]
+      val rows = dataframe.queryExecution.executedPlan.execute()
+
+      if (rows.getNumPartitions > 0) {
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], Long, Long)]) => {
+          if (taskResult.exists(_._1.nonEmpty)) {
+            // only send non-empty partitions
+            val task = pool.submit(new Runnable {
+              override def run(): Unit = {
+                var batchId = partitionId.toLong << 33
+                taskResult.foreach { case (bytes, count, size) =>
+                  val response = proto.Response.newBuilder().setClientId(clientId)
+                  val batch = proto.Response.ArrowBatch
+                    .newBuilder()
+                    .setBatchId(batchId)
+                    .setRowCount(count)
+                    .setUncompressedBytes(size)
+                    .setCompressedBytes(bytes.length)
+                    .setData(ByteString.copyFrom(bytes))
+                    .build()
+                  response.setArrowBatch(batch)
+                  responseObserver.onNext(response.build())

Review Comment:
   To be clear, my suggestion is to defer the deserialization and block releasing to the `JobWaiter#taskSucceeded` phase.
   
   The second way you suggested has many performance benefits but requires the client to communicate with the external storage service.
   
   The `IndirectTaskResult` way leverages the existing code and Spark's built-in block mechanism to transfer data; we can benefit with only a little code modification, and we don't need to worry about result cleanup.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] cloud-fan commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
cloud-fan commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019090800


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +120,93 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = collection.mutable.Map.empty[Int, Array[Batch]]

Review Comment:
   then we don't need a lock.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019117026


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+        val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          val i = 0 // Unit
+        }
+
+        spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          val partition = signal.synchronized {
+            while (partitions(currentPartitionId) == null) {
+              signal.wait()

Review Comment:
   That makes the client and the API more complicated, and I don't want client implementors to deal with this. We can add an optimization for unordered dataframes in a follow-up, but for now let's just merge the thing that works.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019100602


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +120,93 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = collection.mutable.Map.empty[Int, Array[Batch]]

Review Comment:
   > can we apply the same idea to JSON batches?
   
   I think so, let's optimize it later



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] zhengruifeng closed pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng closed pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect
URL: https://github.com/apache/spark/pull/38468


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] HyukjinKwon commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018687996


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##########
@@ -128,6 +128,97 @@ private[sql] object ArrowConverters extends Logging {
     }
   }
 
+  private[sql] def toArrowBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Int,
+      timeZoneId: String): Iterator[(Array[Byte], Long, Long)] = {
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+      "toArrowBatchIterator", 0, Long.MaxValue)
+
+    val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    val unloader = new VectorUnloader(root)
+    val arrowWriter = ArrowWriter.create(root)
+
+    Option(TaskContext.get).foreach {
+      _.addTaskCompletionListener[Unit] { _ =>
+        root.close()
+        allocator.close()
+      }
+    }
+
+    new Iterator[(Array[Byte], Long, Long)] {
+
+      override def hasNext: Boolean = rowIter.hasNext || {
+        root.close()
+        allocator.close()
+        false
+      }
+
+      override def next(): (Array[Byte], Long, Long) = {
+        val out = new ByteArrayOutputStream()
+        val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+        var rowCount = 0L
+        var estimatedSize = SizeEstimator.estimate(arrowSchema) +
+          SizeEstimator.estimate(IpcOption.DEFAULT)
+        Utils.tryWithSafeFinally {
+          while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+            val row = rowIter.next()
+            arrowWriter.write(row)
+            rowCount += 1
+            estimatedSize += SizeEstimator.estimate(row)
+          }
+          arrowWriter.finish()
+          val batch = unloader.getRecordBatch()
+
+          MessageSerializer.serialize(writeChannel, arrowSchema)
+          MessageSerializer.serialize(writeChannel, batch)
+          ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)
+
+          batch.close()
+        } {
+          arrowWriter.reset()
+        }
+
+        (out.toByteArray, rowCount, estimatedSize)
+      }
+    }
+  }
+
+  private[sql] def createEmptyArrowBatch(

Review Comment:
   Can we deduplicate the code with `toBatchIterator`? For example, you can:
   
   ```scala
   private[sql] def createArrowBatch(
       handleBatch: RecordBatch => Unit,
       ...) = {
     ...
     handleBatch(unloader.getRecordBatch())
     ...
   }
   
   def toArrowBatchIterator = {
     ...
     createArrowBatch(..., { batch =>
       ...
     }
     ...
   }
   
   def toBatchIterator = {
     ...
     createArrowBatch(..., { batch =>
       ...
     }
     ...
   }
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018939295


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long, Long)
+
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val partitions = Array.fill[Array[Batch]](numPartitions)(null)

Review Comment:
   will update soon



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013547085


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -49,21 +51,33 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
     }
   }
 
-  def handlePlan(session: SparkSession, request: proto.Request): Unit = {
+  def handlePlan(session: SparkSession, request: Request): Unit = {
     // Extract the plan from the request and convert it to a logical plan
     val planner = new SparkConnectPlanner(request.getPlan.getRoot, session)
-    val rows =
-      Dataset.ofRows(session, planner.transform())
-    processRows(request.getClientId, rows)
+    val dataframe = Dataset.ofRows(session, planner.transform())
+    request.getPreferredResultType match {
+      case Request.ResultType.ArrowBatch =>
+        // check whether all data types are supported

Review Comment:
   For example: CharType, VarcharType, and UserDefinedType (like VectorUDT) are not supported.
   
   The supported types are listed here:
   
   https://github.com/apache/spark/blob/1a90512f605c490255f7b38215c207e64621475b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala#L38-L60
   
   https://github.com/apache/spark/blob/7b8016a578f511d1c17b16393c487429ce08f132/python/pyspark/sql/pandas/types.py#L54-L120
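
   A minimal sketch of such a check, assuming it is acceptable to lean on ArrowUtils.toArrowType to reject anything it cannot map (the helper name `isArrowSupported` is hypothetical, not part of the PR):

   ```
   import org.apache.spark.sql.types._
   import org.apache.spark.sql.util.ArrowUtils

   // Sketch only: recursively verify that a schema uses Arrow-convertible types.
   def isArrowSupported(dt: DataType, timeZoneId: String): Boolean = dt match {
     case _: CharType | _: VarcharType => false
     case _: UserDefinedType[_] => false // e.g. VectorUDT
     case ArrayType(elementType, _) => isArrowSupported(elementType, timeZoneId)
     case MapType(keyType, valueType, _) =>
       isArrowSupported(keyType, timeZoneId) && isArrowSupported(valueType, timeZoneId)
     case StructType(fields) => fields.forall(f => isArrowSupported(f.dataType, timeZoneId))
     case _ =>
       // ArrowUtils.toArrowType throws for types it cannot map to Arrow.
       try { ArrowUtils.toArrowType(dt, timeZoneId); true }
       catch { case _: Exception => false }
   }
   ```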





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013556642


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##########
@@ -128,6 +128,65 @@ private[sql] object ArrowConverters extends Logging {
     }
   }
 
+  private[sql] def toArrowBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Int,
+      timeZoneId: String,
+      context: TaskContext): Iterator[(Array[Byte], Long, Long)] = {
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+      "toArrowBatchIterator", 0, Long.MaxValue)
+
+    val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    val unloader = new VectorUnloader(root)
+    val arrowWriter = ArrowWriter.create(root)
+
+    if (context != null) { // for test at driver
+      context.addTaskCompletionListener[Unit] { _ =>
+        root.close()
+        allocator.close()
+      }
+    }
+
+    new Iterator[(Array[Byte], Long, Long)] {
+
+      override def hasNext: Boolean = rowIter.hasNext || {
+        root.close()
+        allocator.close()
+        false
+      }
+
+      override def next(): (Array[Byte], Long, Long) = {
+        val out = new ByteArrayOutputStream()
+        val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+        var rowCount = 0L
+        var estimatedSize = 0L
+        Utils.tryWithSafeFinally {
+          while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+            val row = rowIter.next()
+            arrowWriter.write(row)
+            rowCount += 1
+            estimatedSize += SizeEstimator.estimate(row)
+          }
+          arrowWriter.finish()
+          val batch = unloader.getRecordBatch()
+
+          MessageSerializer.serialize(writeChannel, arrowSchema)

Review Comment:
   The problem is that we carry an Arrow record batch within each protobuf message, so the schema for each record batch has to be sent along with it in order to read each protobuf message independently.
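
   One consequence is that each serialized payload is a self-contained Arrow IPC stream (schema message followed by one record batch), so a client can deserialize any payload on its own with a plain stream reader. A sketch of that, where `bytes` stands in for the Array[Byte] produced above and the java.io / org.apache.arrow.vector.ipc imports are assumed:

   ```
   // Sketch only: read one self-contained payload produced by toArrowBatchIterator.
   val allocator = ArrowUtils.rootAllocator.newChildAllocator("readBatch", 0, Long.MaxValue)
   val reader = new ArrowStreamReader(new ByteArrayInputStream(bytes), allocator)
   try {
     while (reader.loadNextBatch()) {
       val root = reader.getVectorSchemaRoot
       // root.getSchema is the schema serialized with this payload;
       // root now holds the columns of the record batch.
       println(s"rows in batch: ${root.getRowCount}")
     }
   } finally {
     reader.close()
     allocator.close()
   }
   ```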





[GitHub] [spark] hvanhovell commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
hvanhovell commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1014270879


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,7 +126,70 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
+    responseObserver.onCompleted()
+  }
+
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    val rows = dataframe.queryExecution.executedPlan.execute()

Review Comment:
   You need to wrap this in SQLExecution.withExecutionId(..) or use Dataset.withAction; otherwise you break the UI and a bunch of other things.
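
   For reference, the later revision in this thread does exactly that. A minimal sketch of the shape it takes (the "collectArrow" name matches the updated hunk further down):

   ```
   SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
     // Runs with a proper execution id, so the SQL UI and query execution
     // listeners see this collect like any other action.
     val rows = dataframe.queryExecution.executedPlan.execute()
     // ... build and stream the Arrow batches here ...
   }
   ```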





[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1014559977


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,7 +126,70 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
+    responseObserver.onCompleted()
+  }
+
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    val rows = dataframe.queryExecution.executedPlan.execute()

Review Comment:
   good point





[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1017740678


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -117,10 +127,99 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
       responseObserver.onNext(response.build())
     }
 
-    responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
     responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+    val spark = dataframe.sparkSession
+    val schema = dataframe.schema
+    // TODO: control the batch size instead of max records
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+    SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
+      val rows = dataframe.queryExecution.executedPlan.execute()
+      val numPartitions = rows.getNumPartitions
+      var numSent = 0
+
+      if (numPartitions > 0) {
+        val batches = rows.mapPartitionsInternal { iter =>
+          ArrowConverters
+            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+        }
+
+        val signal = new Object
+        val queue = collection.mutable.Queue.empty[(Int, Array[(Array[Byte], Long, Long)])]

Review Comment:
   OK, I will make sure the batches are sent in order from the server side.
   
   Then we don't need partition_id and batch_id any more.
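
   A sketch of what "sent in order" can look like on the server side, reusing the `signal`/`partitions` slots from the earlier hunk (illustrative only, not the final code): the sender advances a cursor one partition at a time, so batches always leave the server in partition order and the client never has to reorder them.

   ```
   // Sketch only: stream partition results strictly in partition order.
   var currentPartition = 0
   while (currentPartition < numPartitions) {
     val partitionBatches = signal.synchronized {
       while (partitions(currentPartition) == null) {
         signal.wait()
       }
       partitions(currentPartition)
     }
     partitionBatches.foreach { case (bytes, _, _) =>
       // build one Response around `bytes` and call responseObserver.onNext(...) here
       numSent += 1
     }
     currentPartition += 1
   }
   ```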


