Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2022/11/16 23:20:47 UTC

[GitHub] [spark] HyukjinKwon commented on a diff in pull request #38618: [SPARK-41108][SPARK-41005][CONNECT][FOLLOW-UP] Deduplicate ArrowConverters codes

HyukjinKwon commented on code in PR #38618:
URL: https://github.com/apache/spark/pull/38618#discussion_r1024602372


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##########
@@ -71,158 +71,146 @@ private[sql] class ArrowBatchStreamWriter(
 }
 
 private[sql] object ArrowConverters extends Logging {
-
-  /**
-   * Maps Iterator from InternalRow to serialized ArrowRecordBatches. Limit ArrowRecordBatch size
-   * in a batch by setting maxRecordsPerBatch or use 0 to fully consume rowIter.
-   */
-  private[sql] def toBatchIterator(
+  private[sql] class ArrowBatchIterator(
       rowIter: Iterator[InternalRow],
       schema: StructType,
-      maxRecordsPerBatch: Int,
+      maxRecordsPerBatch: Long,
       timeZoneId: String,
-      context: TaskContext): Iterator[Array[Byte]] = {
+      context: TaskContext) extends Iterator[Array[Byte]] {
 
-    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
-    val allocator =
-      ArrowUtils.rootAllocator.newChildAllocator("toBatchIterator", 0, Long.MaxValue)
+    protected val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    private val allocator =
+      ArrowUtils.rootAllocator.newChildAllocator(
+        s"to${this.getClass.getSimpleName}", 0, Long.MaxValue)
 
-    val root = VectorSchemaRoot.create(arrowSchema, allocator)
-    val unloader = new VectorUnloader(root)
-    val arrowWriter = ArrowWriter.create(root)
+    private val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    protected val unloader = new VectorUnloader(root)
+    protected val arrowWriter = ArrowWriter.create(root)
 
-    context.addTaskCompletionListener[Unit] { _ =>
+    Option(context).foreach {_.addTaskCompletionListener[Unit] { _ =>
       root.close()
       allocator.close()
+    }}
+
+    override def hasNext: Boolean = rowIter.hasNext || {
+      root.close()
+      allocator.close()
+      false
     }
 
-    new Iterator[Array[Byte]] {
+    override def next(): Array[Byte] = {
+      val out = new ByteArrayOutputStream()
+      val writeChannel = new WriteChannel(Channels.newChannel(out))
 
-      override def hasNext: Boolean = rowIter.hasNext || {
-        root.close()
-        allocator.close()
-        false
+      Utils.tryWithSafeFinally {
+        var rowCount = 0L
+        while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+          val row = rowIter.next()
+          arrowWriter.write(row)
+          rowCount += 1
+        }
+        arrowWriter.finish()
+        val batch = unloader.getRecordBatch()
+        MessageSerializer.serialize(writeChannel, batch)
+        batch.close()
+      } {
+        arrowWriter.reset()
       }
 
-      override def next(): Array[Byte] = {
-        val out = new ByteArrayOutputStream()
-        val writeChannel = new WriteChannel(Channels.newChannel(out))
-
-        Utils.tryWithSafeFinally {
-          var rowCount = 0
-          while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
-            val row = rowIter.next()
-            arrowWriter.write(row)
-            rowCount += 1
-          }
-          arrowWriter.finish()
-          val batch = unloader.getRecordBatch()
-          MessageSerializer.serialize(writeChannel, batch)
-          batch.close()
-        } {
-          arrowWriter.reset()
+      out.toByteArray
+    }
+  }
+
+  private[sql] class ArrowBatchWithSchemaIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Long,
+      maxEstimatedBatchSize: Long,
+      timeZoneId: String,
+      context: TaskContext)
+    extends ArrowBatchIterator(
+      rowIter, schema, maxRecordsPerBatch, timeZoneId, context) {
+
+    private val arrowSchemaSize = SizeEstimator.estimate(arrowSchema)

Review Comment:
   @hvanhovell, this PR is virtually pure refactoring, except for the couple of points I mentioned in the PR description. As for the question, this came from https://github.com/apache/spark/pull/38612, which estimates the size of the batch before creating an Arrow batch.
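
   For readers following along, here is a minimal, self-contained sketch of the idea referenced above: accumulate rows until either a record-count cap or an estimated byte-size cap is reached, so the batch size is bounded *before* the Arrow batch is materialized. This is not Spark's actual implementation; the `SizeCappedBatchIterator` class, the `sizeOf` estimator, and the caps are illustrative assumptions that only mirror the loop condition visible in the diff (where a cap of 0 or less means "no limit").

   ```scala
   // Hypothetical sketch, not the ArrowConverters code: batches rows by count
   // and by an estimated size, checked before each row is appended.
   class SizeCappedBatchIterator[T <: AnyRef](
       rows: Iterator[T],
       maxRecordsPerBatch: Long,
       maxEstimatedBatchSize: Long,
       sizeOf: T => Long) extends Iterator[Seq[T]] {

     override def hasNext: Boolean = rows.hasNext

     override def next(): Seq[T] = {
       val batch = Seq.newBuilder[T]
       var rowCount = 0L
       var estimatedSize = 0L
       // Mirrors the diff's loop condition: <= 0 disables that cap. Note a
       // batch may overshoot the size cap by one row, since the check runs
       // before the next row is added.
       while (rows.hasNext &&
           (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch) &&
           (maxEstimatedBatchSize <= 0 || estimatedSize < maxEstimatedBatchSize)) {
         val row = rows.next()
         estimatedSize += sizeOf(row)
         rowCount += 1
         batch += row
       }
       batch.result()
     }
   }

   // Usage example: cap at 3 rows or ~64 estimated bytes, whichever comes first.
   val batches = new SizeCappedBatchIterator[String](
     Iterator("a" * 10, "b" * 40, "c" * 40, "d" * 10),
     maxRecordsPerBatch = 3,
     maxEstimatedBatchSize = 64,
     sizeOf = s => s.length.toLong)
   batches.foreach(b => println(b.map(_.length)))  // List(10, 40, 40), List(10)
   ```

   In the real `ArrowBatchWithSchemaIterator`, the per-row estimate would additionally be offset by the serialized schema size (`SizeEstimator.estimate(arrowSchema)` in the diff), which this sketch omits for brevity.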



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org
