Posted to commits@spark.apache.org by gu...@apache.org on 2022/06/13 11:10:28 UTC

[spark] branch master updated: [SPARK-39301][SQL][PYTHON] Leverage LocalRelation and respect Arrow batch size in createDataFrame with Arrow optimization

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d3d22928d4f [SPARK-39301][SQL][PYTHON] Leverage LocalRelation and respect Arrow batch size in createDataFrame with Arrow optimization
d3d22928d4f is described below

commit d3d22928d4fca1cd71f96a1308f0cc5d00120ad5
Author: Hyukjin Kwon <gu...@apache.org>
AuthorDate: Mon Jun 13 20:10:17 2022 +0900

    [SPARK-39301][SQL][PYTHON] Leverage LocalRelation and respect Arrow batch size in createDataFrame with Arrow optimization
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to use `LocalRelation` instead of `LogicalRDD` when creating a (small) DataFrame with Arrow optimization, which keeps the data as a local collection on the driver side (consistent with the Scala code path).
    
    Namely:
    
    ```python
    import pandas as pd
    spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", True)
    spark.createDataFrame(pd.DataFrame({'a': [1, 2, 3, 4]})).explain(True)
    ```
    
    Before
    
    ```
    == Parsed Logical Plan ==
    LogicalRDD [a#0L], false
    
    == Analyzed Logical Plan ==
    a: bigint
    LogicalRDD [a#0L], false
    
    == Optimized Logical Plan ==
    LogicalRDD [a#0L], false
    
    == Physical Plan ==
    *(1) Scan ExistingRDD arrow[a#0L]
    ```
    
    After
    
    ```
    == Parsed Logical Plan ==
    LocalRelation [a#0L]
    
    == Analyzed Logical Plan ==
    a: bigint
    LocalRelation [a#0L]
    
    == Optimized Logical Plan ==
    LocalRelation [a#0L]
    
    == Physical Plan ==
    LocalTableScan [a#0L]
    ```
    
    This is controlled by a new configuration, `spark.sql.execution.arrow.localRelationThreshold`, defaulting to 48MB. The default was picked based on the benchmarks I ran below.
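
    The threshold can be tuned via the SQL conf. A minimal sketch, assuming a running `spark` session with Arrow optimization enabled (setting it to 0 is also how the RDD-based path is exercised in the new test):

    ```python
    import pandas as pd

    # Setting the threshold to 0 forces the RDD-based path even for tiny inputs;
    # with the default of 48MB, small pandas DataFrames stay as a driver-side LocalRelation.
    spark.conf.set("spark.sql.execution.arrow.localRelationThreshold", "0")
    spark.createDataFrame(pd.DataFrame({'a': [1, 2, 3, 4]})).explain()
    # Expected: the physical plan shows "Scan ExistingRDD arrow" instead of "LocalTableScan".
    ```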
    
    In addition, this PR fixes `createDataFrame` to respect the `spark.sql.execution.arrow.maxRecordsPerBatch` configuration when creating Arrow batches. Previously, we divided the input pandas DataFrame into the default number of partitions, which forced users to raise `spark.rpc.message.maxSize` when the input pandas DataFrame was too large. See the benchmarks below.
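
    A minimal sketch of the new batching behavior, assuming Arrow optimization is enabled in a running `spark` session:

    ```python
    import pandas as pd

    # Hypothetical illustration: with maxRecordsPerBatch set to 2, the 4-row pandas
    # DataFrame below is serialized as two 2-record Arrow batches on the Python side,
    # instead of being sliced by the default parallelism as before.
    spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "2")
    spark.createDataFrame(pd.DataFrame({'a': [1, 2, 3, 4]})).show()
    ```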
    
    ### Why are the changes needed?
    
    We have some nice optimizations for `LocalRelation` (e.g., `ConvertToLocalRelation`). For example, the stats are fully known when you use `LocalRelation`, whereas many optimizations cannot be applied with `LogicalRDD`. In some cases (e.g., `executeCollect`), we can also avoid creating `RDD`s entirely.
    
    As for respecting `spark.sql.execution.arrow.maxRecordsPerBatch`: 1. we can avoid forcing users to set `spark.rpc.message.maxSize`, and 2. I believe the configuration is supposed to be respected by every code path that creates Arrow batches whenever possible.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, it is an optimization. The number of resulting partitions can differ, but that is an internal detail.
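
    One way to observe that internal difference, as a sketch assuming a running `spark` session:

    ```python
    import pandas as pd

    # Which path is taken (LocalTableScan vs. Scan ExistingRDD) depends on
    # spark.sql.execution.arrow.localRelationThreshold; the partition count is an
    # internal detail and may differ from the previous defaultParallelism-based slicing.
    df = spark.createDataFrame(pd.DataFrame({'a': list(range(100))}))
    df.explain()
    print(df.rdd.getNumPartitions())
    ```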
    
    ### How was this patch tested?
    
    - Manually tested.
    - Added a unit test.
    - Ran two benchmarks with 1 driver and 4 workers (i3.xlarge); see below.
    
    #### Benchmark 1 (best cases)
    
    ```python
    import time
    import random
    import string
    
    import pandas as pd
    
    names = [random.choice(list(string.ascii_lowercase)) for i in range(1000)]
    ages = [random.randint(0, 100) for i in range(1000)]
    pdf = pd.DataFrame({'name': names, 'age': ages})
    spark.range(1).count()  # warm up
    
    start = time.time()
    for _ in range(100):
        _ = spark.createDataFrame(pdf)
    
    end = time.time()
    print(end - start)
    ```
    
    Before
    
    10.250491698582968 seconds
    
    After
    
    6.004616181055705 seconds
    
    #### Benchmark 2 (worst cases)
    
    ```bash
    curl -O https://eforexcel.com/wp/wp-content/uploads/2020/09/HR2m.zip
    unzip HR2m.zip
    ```
    
    ```python
    import pandas as pd
    pdf = pd.read_csv("HR2m.csv")
    pdf23 = pdf.iloc[:int(len(pdf)/32)]
    pdf45 = pdf.iloc[:int(len(pdf)/16)]
    pdf90 = pdf.iloc[:int(len(pdf)/8)]
    pdf175 = pdf.iloc[:int(len(pdf)/4)]
    pdf350 = pdf.iloc[:int(len(pdf)/2)]
    pdf700 = pdf.iloc[:int(len(pdf))]
    pdf2gb = pd.concat([pdf, pdf, pdf])
    pdf5gb = pd.concat([pdf2gb, pdf2gb])
    
    spark.createDataFrame(pdf23)._jdf.rdd().count()  # explicitly create RDD.
    ...
    ```
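
    The rest of the measurement code is elided above; a sketch of how the remaining sizes could have been timed, reusing the frames defined in the snippet (the `pdf700` frame is built but has no corresponding result below):

    ```python
    import time

    # Hypothetical timing loop for each input size.
    for name, frame in [("23MB", pdf23), ("45MB", pdf45), ("90MB", pdf90),
                        ("175MB", pdf175), ("350MB", pdf350), ("2GB", pdf2gb), ("5GB", pdf5gb)]:
        start = time.time()
        spark.createDataFrame(frame)._jdf.rdd().count()  # force execution, as above
        print(name, time.time() - start)
    ```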
    
    Before
    
    23MB: 1.02 seconds
    45MB: 1.69 seconds
    90MB: 2.38 seconds
    175MB: 3.19 seconds
    350MB: 6.10 seconds
    2GB: 43.21 seconds
    5GB: X (threw an exception saying to set 'spark.rpc.message.maxSize' higher)
    
    After
    
    23MB: 1.31 seconds (local collection is used)
    45MB: 2.47 seconds (local collection is used)
    90MB: 1.79 seconds
    175MB: 3.22 seconds
    350MB: 6.41 seconds
    2GB: 47.12 seconds
    5GB: 1.29 minutes
    
    **NOTE** that the performance varies depending on network stability, and the numbers above are from the second run (not an average).
    
    Closes #36683 from HyukjinKwon/SPARK-39301.
    
    Authored-by: Hyukjin Kwon <gu...@apache.org>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/context.py                          | 10 ++--
 python/pyspark/sql/pandas/conversion.py            | 12 ++---
 python/pyspark/sql/tests/test_arrow.py             | 12 +++++
 .../org/apache/spark/sql/internal/SQLConf.scala    | 14 +++++
 .../spark/sql/api/python/PythonSQLUtils.scala      | 38 ++++++-------
 .../org/apache/spark/sql/api/r/SQLUtils.scala      | 12 ++++-
 .../sql/execution/arrow/ArrowConverters.scala      | 62 +++++++++++++++-------
 7 files changed, 109 insertions(+), 51 deletions(-)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 59b5fa7f3a4..11d75f4f99a 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -679,10 +679,10 @@ class SparkContext:
         data: Iterable[T],
         serializer: Serializer,
         reader_func: Callable,
-        createRDDServer: Callable,
+        server_func: Callable,
     ) -> JavaObject:
         """
-        Using py4j to send a large dataset to the jvm is really slow, so we use either a file
+        Using Py4J to send a large dataset to the jvm is slow, so we use either a file
         or a socket if we have encryption enabled.
 
         Examples
@@ -693,13 +693,13 @@ class SparkContext:
         reader_func : function
             A function which takes a filename and reads in the data in the jvm and
             returns a JavaRDD. Only used when encryption is disabled.
-        createRDDServer : function
-            A function which creates a PythonRDDServer in the jvm to
+        server_func : function
+            A function which creates a SocketAuthServer in the JVM to
             accept the serialized data, for use when encryption is enabled.
         """
         if self._encryption_enabled:
             # with encryption, we open a server in java and send the data directly
-            server = createRDDServer()
+            server = server_func()
             (sock_file, _) = local_connect_and_auth(server.port(), server.secret())
             chunked_out = ChunkedStream(sock_file, 8192)
             serializer.dump_stream(data, chunked_out)
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index fff0bac5480..119a9bf315c 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -596,7 +596,7 @@ class SparkConversionMixin:
             ]
 
         # Slice the DataFrame to be batched
-        step = -(-len(pdf) // self.sparkContext.defaultParallelism)  # round int up
+        step = self._jconf.arrowMaxRecordsPerBatch()
         pdf_slices = (pdf.iloc[start : start + step] for start in range(0, len(pdf), step))
 
         # Create list of Arrow (columns, type) for serializer dump_stream
@@ -613,16 +613,16 @@ class SparkConversionMixin:
 
         @no_type_check
         def reader_func(temp_filename):
-            return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsparkSession, temp_filename)
+            return self._jvm.PythonSQLUtils.readArrowStreamFromFile(temp_filename)
 
         @no_type_check
-        def create_RDD_server():
-            return self._jvm.ArrowRDDServer(jsparkSession)
+        def create_iter_server():
+            return self._jvm.ArrowIteratorServer()
 
         # Create Spark DataFrame from Arrow stream file, using one batch per partition
-        jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server)
+        jiter = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_iter_server)
         assert self._jvm is not None
-        jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsparkSession)
+        jdf = self._jvm.PythonSQLUtils.toDataFrame(jiter, schema.json(), jsparkSession)
         df = DataFrame(jdf, self)
         df._schema = schema
         return df
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index b737848b11a..9b1b204542b 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -803,6 +803,18 @@ class EncryptionArrowTests(ArrowTests):
         return super(EncryptionArrowTests, cls).conf().set("spark.io.encryption.enabled", "true")
 
 
+class RDDBasedArrowTests(ArrowTests):
+    @classmethod
+    def conf(cls):
+        return (
+            super(RDDBasedArrowTests, cls)
+            .conf()
+            .set("spark.sql.execution.arrow.localRelationThreshold", "0")
+            # to test multiple partitions
+            .set("spark.sql.execution.arrow.maxRecordsPerBatch", "2")
+        )
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests.test_arrow import *  # noqa: F401
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b8a752e90ec..4b64d91e56a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2576,6 +2576,18 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val ARROW_LOCAL_RELATION_THRESHOLD =
+    buildConf("spark.sql.execution.arrow.localRelationThreshold")
+      .doc(
+        "When converting Arrow batches to Spark DataFrame, local collections are used in the " +
+          "driver side if the byte size of Arrow batches is smaller than this threshold. " +
+          "Otherwise, the Arrow batches are sent and deserialized to Spark internal rows " +
+          "in the executors.")
+      .version("3.4.0")
+      .bytesConf(ByteUnit.BYTE)
+      .checkValue(_ >= 0, "This value must be equal to or greater than 0.")
+      .createWithDefaultString("48MB")
+
   val PYSPARK_JVM_STACKTRACE_ENABLED =
     buildConf("spark.sql.pyspark.jvmStacktrace.enabled")
       .doc("When true, it shows the JVM stacktrace in the user-facing PySpark exception " +
@@ -4418,6 +4430,8 @@ class SQLConf extends Serializable with Logging {
 
   def arrowPySparkEnabled: Boolean = getConf(ARROW_PYSPARK_EXECUTION_ENABLED)
 
+  def arrowLocalRelationThreshold: Long = getConf(ARROW_LOCAL_RELATION_THRESHOLD)
+
   def arrowPySparkSelfDestructEnabled: Boolean = getConf(ARROW_PYSPARK_SELF_DESTRUCT_ENABLED)
 
   def pysparkJVMStacktraceEnabled: Boolean = getConf(PYSPARK_JVM_STACKTRACE_ENABLED)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index fd689bf502a..a3ba8636233 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -18,15 +18,15 @@
 package org.apache.spark.sql.api.python
 
 import java.io.InputStream
+import java.net.Socket
 import java.nio.channels.Channels
 import java.util.Locale
 
 import net.razorvine.pickle.Pickler
 
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.api.python.PythonRDDServer
+import org.apache.spark.api.python.DechunkedInputStream
 import org.apache.spark.internal.Logging
-import org.apache.spark.rdd.RDD
+import org.apache.spark.security.SocketAuthServer
 import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession}
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
@@ -70,22 +70,22 @@ private[sql] object PythonSQLUtils extends Logging {
     SQLConf.get.timestampType == org.apache.spark.sql.types.TimestampNTZType
 
   /**
-   * Python callable function to read a file in Arrow stream format and create a [[RDD]]
-   * using each serialized ArrowRecordBatch as a partition.
+   * Python callable function to read a file in Arrow stream format and create an iterator
+   * of serialized ArrowRecordBatches.
    */
-  def readArrowStreamFromFile(session: SparkSession, filename: String): JavaRDD[Array[Byte]] = {
-    ArrowConverters.readArrowStreamFromFile(session, filename)
+  def readArrowStreamFromFile(filename: String): Iterator[Array[Byte]] = {
+    ArrowConverters.readArrowStreamFromFile(filename).iterator
   }
 
   /**
    * Python callable function to read a file in Arrow stream format and create a [[DataFrame]]
-   * from an RDD.
+   * from the Arrow batch iterator.
    */
   def toDataFrame(
-      arrowBatchRDD: JavaRDD[Array[Byte]],
+      arrowBatches: Iterator[Array[Byte]],
       schemaString: String,
       session: SparkSession): DataFrame = {
-    ArrowConverters.toDataFrame(arrowBatchRDD, schemaString, session)
+    ArrowConverters.toDataFrame(arrowBatches, schemaString, session)
   }
 
   def explainString(queryExecution: QueryExecution, mode: String): String = {
@@ -137,16 +137,16 @@ private[sql] object PythonSQLUtils extends Logging {
 }
 
 /**
- * Helper for making a dataframe from arrow data from data sent from python over a socket.  This is
+ * Helper for making a dataframe from Arrow data from data sent from python over a socket. This is
  * used when encryption is enabled, and we don't want to write data to a file.
  */
-private[sql] class ArrowRDDServer(session: SparkSession) extends PythonRDDServer {
-
-  override protected def streamToRDD(input: InputStream): RDD[Array[Byte]] = {
-    // Create array to consume iterator so that we can safely close the inputStream
-    val batches = ArrowConverters.getBatchesFromStream(Channels.newChannel(input)).toArray
-    // Parallelize the record batches to create an RDD
-    JavaRDD.fromRDD(session.sparkContext.parallelize(batches, batches.length))
+private[spark] class ArrowIteratorServer
+  extends SocketAuthServer[Iterator[Array[Byte]]]("pyspark-arrow-batches-server") {
+
+  def handleConnection(sock: Socket): Iterator[Array[Byte]] = {
+    val in = sock.getInputStream()
+    val dechunkedInput: InputStream = new DechunkedInputStream(in)
+    // Create array to consume iterator so that we can safely close the file
+    ArrowConverters.getBatchesFromStream(Channels.newChannel(dechunkedInput)).toArray.iterator
   }
-
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 7831ddee4f9..f58afcfa05d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -23,6 +23,7 @@ import java.util.{Locale, Map => JMap}
 import scala.collection.JavaConverters._
 import scala.util.matching.Regex
 
+import org.apache.spark.TaskContext
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.api.r.SerDe
 import org.apache.spark.broadcast.Broadcast
@@ -230,7 +231,9 @@ private[sql] object SQLUtils extends Logging {
   def readArrowStreamFromFile(
       sparkSession: SparkSession,
       filename: String): JavaRDD[Array[Byte]] = {
-    ArrowConverters.readArrowStreamFromFile(sparkSession, filename)
+    // Parallelize the record batches to create an RDD
+    val batches = ArrowConverters.readArrowStreamFromFile(filename)
+    JavaRDD.fromRDD(sparkSession.sparkContext.parallelize(batches, batches.length))
   }
 
   /**
@@ -241,6 +244,11 @@ private[sql] object SQLUtils extends Logging {
       arrowBatchRDD: JavaRDD[Array[Byte]],
       schema: StructType,
       sparkSession: SparkSession): DataFrame = {
-    ArrowConverters.toDataFrame(arrowBatchRDD, schema.json, sparkSession)
+    val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
+    val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
+      val context = TaskContext.get()
+      ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
+    }
+    sparkSession.internalCreateDataFrame(rdd.setName("arrow"), schema)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 93ff276529d..bded158645c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -29,10 +29,12 @@ import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel
 import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, IpcOption, MessageSerializer}
 
 import org.apache.spark.TaskContext
-import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.JavaUtils
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
@@ -68,7 +70,7 @@ private[sql] class ArrowBatchStreamWriter(
   }
 }
 
-private[sql] object ArrowConverters {
+private[sql] object ArrowConverters extends Logging {
 
   /**
    * Maps Iterator from InternalRow to serialized ArrowRecordBatches. Limit ArrowRecordBatch size
@@ -143,7 +145,7 @@ private[sql] object ArrowConverters {
     new Iterator[InternalRow] {
       private var rowIter = if (arrowBatchIter.hasNext) nextBatch() else Iterator.empty
 
-      context.addTaskCompletionListener[Unit] { _ =>
+      if (context != null) context.addTaskCompletionListener[Unit] { _ =>
         root.close()
         allocator.close()
       }
@@ -190,32 +192,54 @@ private[sql] object ArrowConverters {
   }
 
   /**
-   * Create a DataFrame from an RDD of serialized ArrowRecordBatches.
+   * Create a DataFrame from an iterator of serialized ArrowRecordBatches.
    */
-  private[sql] def toDataFrame(
-      arrowBatchRDD: JavaRDD[Array[Byte]],
+  /**
+   * Create a DataFrame from an iterator of serialized ArrowRecordBatches.
+   */
+  def toDataFrame(
+      arrowBatches: Iterator[Array[Byte]],
       schemaString: String,
       session: SparkSession): DataFrame = {
     val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
-    val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
-    val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
-      val context = TaskContext.get()
-      ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
+    val attrs = schema.toAttributes
+    val batchesInDriver = arrowBatches.toArray
+    val shouldUseRDD = session.sessionState.conf
+      .arrowLocalRelationThreshold < batchesInDriver.map(_.length.toLong).sum
+
+    if (shouldUseRDD) {
+      logDebug("Using RDD-based createDataFrame with Arrow optimization.")
+      val timezone = session.sessionState.conf.sessionLocalTimeZone
+      val rdd = session.sparkContext.parallelize(batchesInDriver, batchesInDriver.length)
+        .mapPartitions { batchesInExecutors =>
+          ArrowConverters.fromBatchIterator(
+            batchesInExecutors,
+            schema,
+            timezone,
+            TaskContext.get())
+        }
+      session.internalCreateDataFrame(rdd.setName("arrow"), schema)
+    } else {
+      logDebug("Using LocalRelation in createDataFrame with Arrow optimization.")
+      val data = ArrowConverters.fromBatchIterator(
+        batchesInDriver.toIterator,
+        schema,
+        session.sessionState.conf.sessionLocalTimeZone,
+        TaskContext.get())
+
+      // Project/copy it. Otherwise, the Arrow column vectors will be closed and released out.
+      val proj = UnsafeProjection.create(attrs, attrs)
+      Dataset.ofRows(session, LocalRelation(attrs, data.map(r => proj(r).copy()).toArray))
     }
-    session.internalCreateDataFrame(rdd.setName("arrow"), schema)
   }
 
   /**
-   * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches.
+   * Read a file as an Arrow stream and return an array of serialized ArrowRecordBatches.
    */
-  private[sql] def readArrowStreamFromFile(
-      session: SparkSession,
-      filename: String): JavaRDD[Array[Byte]] = {
+  private[sql] def readArrowStreamFromFile(filename: String): Array[Array[Byte]] = {
     Utils.tryWithResource(new FileInputStream(filename)) { fileStream =>
       // Create array to consume iterator so that we can safely close the file
-      val batches = getBatchesFromStream(fileStream.getChannel).toArray
-      // Parallelize the record batches to create an RDD
-      JavaRDD.fromRDD(session.sparkContext.parallelize(batches, batches.length))
+      getBatchesFromStream(fileStream.getChannel).toArray
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org