Posted to commits@spark.apache.org by hv...@apache.org on 2023/06/06 01:30:21 UTC

[spark] branch master updated: [SPARK-42626][CONNECT] Add Destructive Iterator for SparkResult

This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 62338ed6cd9 [SPARK-42626][CONNECT] Add Destructive Iterator for SparkResult
62338ed6cd9 is described below

commit 62338ed6cd9fba8bb92ec11cea643077e4b69db4
Author: Tengfei Huang <te...@gmail.com>
AuthorDate: Mon Jun 5 21:30:02 2023 -0400

    [SPARK-42626][CONNECT] Add Destructive Iterator for SparkResult
    
    ### What changes were proposed in this pull request?
    Add a destructive iterator to SparkResult and change `Dataset.toLocalIterator` to use the destructive iterator.
    With the destructive iterator, we will:
    1. Close the `ColumnarBatch` once its data has been consumed;
    2. Remove the `ColumnarBatch` from the batches held by `SparkResult` (see the sketch below).
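
    A rough, self-contained sketch of this consume-then-discard pattern (illustration only, not the actual `SparkResult` code; `FakeBatch` and `DestructiveResult` are made-up names):

    ```scala
    import scala.collection.mutable

    // Simplified illustration: a "batch" that must be closed once consumed.
    final case class FakeBatch(rows: Seq[Long]) extends AutoCloseable {
      override def close(): Unit = println(s"closed batch of ${rows.size} rows")
    }

    // Batches are kept in an index -> batch map; the destructive iterator closes
    // and drops each batch as soon as its rows have been consumed.
    class DestructiveResult(batchSeq: Seq[FakeBatch]) {
      private val idxToBatches = mutable.Map(batchSeq.zipWithIndex.map(_.swap): _*)

      def destructiveIterator: Iterator[Long] = new Iterator[Long] {
        private var batchIndex = -1
        private var rows: Iterator[Long] = Iterator.empty

        override def hasNext: Boolean = {
          if (rows.hasNext) return true
          // Destructive step: close and forget the batch we just finished.
          idxToBatches.remove(batchIndex).foreach(_.close())
          batchIndex += 1
          idxToBatches.get(batchIndex) match {
            case Some(batch) => rows = batch.rows.iterator; hasNext // skip empty batches
            case None => false
          }
        }

        override def next(): Long = {
          if (!hasNext) throw new NoSuchElementException("empty iterator")
          rows.next()
        }
      }
    }

    // Usage: each batch is closed as soon as it has been drained.
    object Demo extends App {
      val result = new DestructiveResult(Seq(FakeBatch(Seq(1L, 2L)), FakeBatch(Seq(3L))))
      result.destructiveIterator.foreach(println)
    }
    ```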
    
    ### Why are the changes needed?
    Instead of keeping everything in memory for the lifetime of the SparkResult object, clean up each batch as soon as we know we are done with it.
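
    For example (a hypothetical usage sketch, assuming an active Spark Connect `SparkSession` bound to `spark`), a caller that streams rows through `toLocalIterator()` now only keeps the batch it is currently consuming alive:

    ```scala
    import scala.collection.JavaConverters._

    // Hypothetical usage: iterate over a large result without retaining every batch.
    val rows = spark.range(0, 1000000L).toLocalIterator().asScala
    // Each backing ColumnarBatch is closed and dropped as soon as its rows have been
    // consumed, rather than being kept until SparkResult.close() is called.
    val total = rows.map(_.longValue()).sum
    ```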
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    UT added.
    
    Closes #40610 from ivoson/SPARK-42626.
    
    Authored-by: Tengfei Huang <te...@gmail.com>
    Signed-off-by: Herman van Hovell <he...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  3 +-
 .../spark/sql/connect/client/SparkResult.scala     | 43 ++++++++++++------
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  | 52 +++++++++++++++++++++-
 3 files changed, 81 insertions(+), 17 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 7a680bde7d3..eba425ce127 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2768,8 +2768,7 @@ class Dataset[T] private[sql] (
    * @since 3.4.0
    */
   def toLocalIterator(): java.util.Iterator[T] = {
-    // TODO make this a destructive iterator.
-    collectResult().iterator
+    collectResult().destructiveIterator
   }
 
   /**
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index 49db44bd855..86a7cf846f2 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -46,7 +46,8 @@ private[sql] class SparkResult[T](
   private[this] var numRecords: Int = 0
   private[this] var structType: StructType = _
   private[this] var boundEncoder: ExpressionEncoder[T] = _
-  private[this] val batches = mutable.Buffer.empty[ColumnarBatch]
+  private[this] var nextBatchIndex: Int = 0
+  private val idxToBatches = mutable.Map.empty[Int, ColumnarBatch]
 
   private def createEncoder(schema: StructType): ExpressionEncoder[T] = {
     val agnosticEncoder = if (encoder == UnboundRowEncoder) {
@@ -70,12 +71,12 @@ private[sql] class SparkResult[T](
         val reader = new ArrowStreamReader(ipcStreamBytes.newInput(), allocator)
         try {
           val root = reader.getVectorSchemaRoot
-          if (batches.isEmpty) {
-            if (structType == null) {
-              // If the schema is not available yet, fallback to the schema from Arrow.
-              structType = ArrowUtils.fromArrowSchema(root.getSchema)
-            }
-            // TODO: create encoders that directly operate on arrow vectors.
+          if (structType == null) {
+            // If the schema is not available yet, fallback to the schema from Arrow.
+            structType = ArrowUtils.fromArrowSchema(root.getSchema)
+          }
+          // TODO: create encoders that directly operate on arrow vectors.
+          if (boundEncoder == null) {
             boundEncoder = createEncoder(structType).resolveAndBind(structType.toAttributes)
           }
           while (reader.loadNextBatch()) {
@@ -85,7 +86,8 @@ private[sql] class SparkResult[T](
               val vectors = root.getFieldVectors.asScala
                 .map(v => new ArrowColumnVector(transferToNewVector(v)))
                 .toArray[ColumnVector]
-              batches += new ColumnarBatch(vectors, rowCount)
+              idxToBatches.put(nextBatchIndex, new ColumnarBatch(vectors, rowCount))
+              nextBatchIndex += 1
               numRecords += rowCount
               if (stopOnFirstNonEmptyResponse) {
                 return true
@@ -142,24 +144,39 @@ private[sql] class SparkResult[T](
   /**
    * Returns an iterator over the contents of the result.
    */
-  def iterator: java.util.Iterator[T] with AutoCloseable = {
+  def iterator: java.util.Iterator[T] with AutoCloseable =
+    buildIterator(destructive = false)
+
+  /**
+   * Returns a destructive iterator over the contents of the result.
+   */
+  def destructiveIterator: java.util.Iterator[T] with AutoCloseable =
+    buildIterator(destructive = true)
+
+  private def buildIterator(destructive: Boolean): java.util.Iterator[T] with AutoCloseable = {
     new java.util.Iterator[T] with AutoCloseable {
       private[this] var batchIndex: Int = -1
       private[this] var iterator: java.util.Iterator[InternalRow] = Collections.emptyIterator()
       private[this] var deserializer: Deserializer[T] = _
+
       override def hasNext: Boolean = {
         if (iterator.hasNext) {
           return true
         }
+
         val nextBatchIndex = batchIndex + 1
-        val hasNextBatch = if (nextBatchIndex == batches.size) {
+        if (destructive) {
+          idxToBatches.remove(batchIndex).foreach(_.close())
+        }
+
+        val hasNextBatch = if (!idxToBatches.contains(nextBatchIndex)) {
           processResponses(stopOnFirstNonEmptyResponse = true)
         } else {
           true
         }
         if (hasNextBatch) {
           batchIndex = nextBatchIndex
-          iterator = batches(nextBatchIndex).rowIterator()
+          iterator = idxToBatches(nextBatchIndex).rowIterator()
           if (deserializer == null) {
             deserializer = boundEncoder.createDeserializer()
           }
@@ -182,8 +199,8 @@ private[sql] class SparkResult[T](
    * Close this result, freeing any underlying resources.
    */
   override def close(): Unit = {
-    batches.foreach(_.close())
+    idxToBatches.values.foreach(_.close())
   }
 
-  override def cleaner: AutoCloseable = AutoCloseables(batches.toSeq)
+  override def cleaner: AutoCloseable = AutoCloseables(idxToBatches.values.toSeq)
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 1a775f55ff5..bdef6b92ece 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -21,6 +21,7 @@ import java.nio.file.Files
 import java.util.Properties
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration._
 import scala.util.{Failure, Success}
@@ -30,21 +31,23 @@ import org.apache.commons.io.FileUtils
 import org.apache.commons.io.output.TeeOutputStream
 import org.apache.commons.lang3.{JavaVersion, SystemUtils}
 import org.scalactic.TolerantNumerics
+import org.scalatest.PrivateMethodTester
 import org.scalatest.concurrent.Eventually._
 
 import org.apache.spark.{SPARK_VERSION, SparkException}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.connect.client.SparkConnectClient
+import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult}
 import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession}
 import org.apache.spark.sql.connect.client.util.SparkConnectServerUtils.port
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.ThreadUtils
 
-class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {
+class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester {
 
   // Spark Result
   test("spark result schema") {
@@ -890,6 +893,51 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {
     assert(message.contains("PARSE_SYNTAX_ERROR"))
   }
 
+  test("Dataset result destructive iterator") {
+    // Helper methods for accessing private field `idxToBatches` from SparkResult
+    val _idxToBatches =
+      PrivateMethod[mutable.Map[Int, ColumnarBatch]](Symbol("idxToBatches"))
+
+    def getColumnarBatches(result: SparkResult[_]): Seq[ColumnarBatch] = {
+      val idxToBatches = result invokePrivate _idxToBatches()
+
+      // Sort by key to get stable results.
+      idxToBatches.toSeq.sortBy(_._1).map(_._2)
+    }
+
+    val df = spark
+      .range(0, 10, 1, 10)
+      .filter("id > 5 and id < 9")
+
+    df.withResult { result =>
+      try {
+        // build and verify the destructive iterator
+        val iterator = result.destructiveIterator
+        // batches is empty before traversing the result iterator
+        assert(getColumnarBatches(result).isEmpty)
+        var previousBatch: ColumnarBatch = null
+        val buffer = mutable.Buffer.empty[Long]
+        while (iterator.hasNext) {
+          // There is always exactly 1 batch, since a columnar batch is removed and closed
+          // after its data has been consumed.
+          val batches = getColumnarBatches(result)
+          assert(batches.size === 1)
+          assert(batches.head != previousBatch)
+          previousBatch = batches.head
+
+          buffer.append(iterator.next())
+        }
+        // Batches should be closed and removed after traversing all the records.
+        assert(getColumnarBatches(result).isEmpty)
+
+        val expectedResult = Seq(6L, 7L, 8L)
+        assert(buffer.size === 3 && expectedResult.forall(buffer.contains))
+      } finally {
+        result.close()
+      }
+    }
+  }
+
   test("SparkSession.createDataFrame - large data set") {
     val threshold = 1024 * 1024
     withSQLConf(SQLConf.LOCAL_RELATION_CACHE_THRESHOLD.key -> threshold.toString) {

