You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by "zhenlineo (via GitHub)" <gi...@apache.org> on 2023/07/18 18:40:07 UTC

[GitHub] [spark] zhenlineo commented on a diff in pull request #42011: [SPARK-44396][Connect] Direct Arrow Deserialization

zhenlineo commented on code in PR #42011:
URL: https://github.com/apache/spark/pull/42011#discussion_r1267152175


##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala:
##########
@@ -172,52 +171,93 @@ private[sql] class SparkResult[T](
 
   private def buildIterator(destructive: Boolean): java.util.Iterator[T] with AutoCloseable = {
     new java.util.Iterator[T] with AutoCloseable {
-      private[this] var batchIndex: Int = -1
-      private[this] var iterator: java.util.Iterator[InternalRow] = Collections.emptyIterator()
-      private[this] var deserializer: Deserializer[T] = _
+      private[this] var iterator: CloseableIterator[T] = _
 
-      override def hasNext: Boolean = {
-        if (iterator.hasNext) {
-          return true
-        }
-
-        val nextBatchIndex = batchIndex + 1
-        if (destructive) {
-          idxToBatches.remove(batchIndex).foreach(_.close())
+      private def initialize(): Unit = {
+        if (iterator == null) {
+          iterator = new ArrowDeserializingIterator(
+            createEncoder(encoder, schema),
+            new ConcatenatingArrowStreamReader(
+              allocator,
+              Iterator.single(new ResultMessageIterator(destructive)),
+              destructive))
         }
+      }
 
-        val hasNextBatch = if (!idxToBatches.contains(nextBatchIndex)) {
-          processResponses(stopOnFirstNonEmptyResponse = true)
-        } else {
-          true
-        }
-        if (hasNextBatch) {
-          batchIndex = nextBatchIndex
-          iterator = idxToBatches(nextBatchIndex).rowIterator()
-          if (deserializer == null) {
-            deserializer = boundEncoder.createDeserializer()
-          }
-        }
-        hasNextBatch
+      override def hasNext: Boolean = {
+        initialize()
+        iterator.hasNext
       }
 
       override def next(): T = {
-        if (!hasNext) {
-          throw new NoSuchElementException
-        }
-        deserializer(iterator.next())
+        initialize()
+        iterator.next()
       }
 
-      override def close(): Unit = SparkResult.this.close()
+      override def close(): Unit = {

Review Comment:
   So if no data has been read yet via `hasNext` or `next`, is the data simply never consumed? Will it be dropped? And is the cleaner releasing the resources in a background thread?



##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala:
##########
@@ -76,61 +71,63 @@ private[sql] class SparkResult[T](
     }
   }
 
-  private def processResponses(stopOnFirstNonEmptyResponse: Boolean): Boolean = {
-    while (responses.hasNext) {
+  private def processResponses(
+      stopOnSchema: Boolean = false,
+      stopOnArrowSchema: Boolean = false,
+      stopOnFirstNonEmptyResponse: Boolean = false): Boolean = {
+    var nonEmpty = false
+    var stop = false
+    while (!stop && responses.hasNext) {
       val response = responses.next()
       if (response.hasSchema) {
         // The original schema should arrive before ArrowBatches.
         structType =
           DataTypeProtoConverter.toCatalystType(response.getSchema).asInstanceOf[StructType]
-      } else if (response.hasArrowBatch) {
+        stop |= stopOnSchema
+      }
+      if (response.hasArrowBatch) {
         val ipcStreamBytes = response.getArrowBatch.getData
-        val reader = new ArrowStreamReader(ipcStreamBytes.newInput(), allocator)
-        try {
-          val root = reader.getVectorSchemaRoot
-          if (structType == null) {
-            // If the schema is not available yet, fallback to the schema from Arrow.
-            structType = ArrowUtils.fromArrowSchema(root.getSchema)
-          }
-          // TODO: create encoders that directly operate on arrow vectors.
-          if (boundEncoder == null) {
-            boundEncoder = createEncoder(structType)
-              .resolveAndBind(DataTypeUtils.toAttributes(structType))
-          }
-          while (reader.loadNextBatch()) {
-            val rowCount = root.getRowCount
-            if (rowCount > 0) {
-              val vectors = root.getFieldVectors.asScala
-                .map(v => new ArrowColumnVector(transferToNewVector(v)))
-                .toArray[ColumnVector]
-              idxToBatches.put(nextBatchIndex, new ColumnarBatch(vectors, rowCount))
-              nextBatchIndex += 1
-              numRecords += rowCount
-              if (stopOnFirstNonEmptyResponse) {
-                return true
-              }
-            }
+        val reader = new MessageIterator(ipcStreamBytes.newInput(), allocator)
+        if (arrowSchema == null) {
+          arrowSchema = reader.schema
+          stop |= stopOnArrowSchema
+        } else if (arrowSchema != reader.schema) {
+          // Uh oh...

Review Comment:
   Maybe throw an `IllegalStateException` or assert something here rather than quietly dropping the mismatch? Otherwise, we need to at least document when this can happen and which schema is used in that case.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org