Posted to commits@spark.apache.org by gu...@apache.org on 2022/06/13 11:10:28 UTC
[spark] branch master updated: [SPARK-39301][SQL][PYTHON] Leverage LocalRelation and respect Arrow batch size in createDataFrame with Arrow optimization
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d3d22928d4f [SPARK-39301][SQL][PYTHON] Leverage LocalRelation and respect Arrow batch size in createDataFrame with Arrow optimization
d3d22928d4f is described below
commit d3d22928d4fca1cd71f96a1308f0cc5d00120ad5
Author: Hyukjin Kwon <gu...@apache.org>
AuthorDate: Mon Jun 13 20:10:17 2022 +0900
[SPARK-39301][SQL][PYTHON] Leverage LocalRelation and respect Arrow batch size in createDataFrame with Arrow optimization
### What changes were proposed in this pull request?
This PR proposes to use `LocalRelation` instead of `LogicalRDD` when creating a (small) DataFrame with Arrow optimization, passing the data as a local collection on the driver side (consistent with the Scala code path).
Namely:
```python
import pandas as pd
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", True)
spark.createDataFrame(pd.DataFrame({'a': [1, 2, 3, 4]})).explain(True)
```
Before
```
== Parsed Logical Plan ==
LogicalRDD [a#0L], false
== Analyzed Logical Plan ==
a: bigint
LogicalRDD [a#0L], false
== Optimized Logical Plan ==
LogicalRDD [a#0L], false
== Physical Plan ==
*(1) Scan ExistingRDD arrow[a#0L]
```
After
```
== Parsed Logical Plan ==
LocalRelation [a#0L]
== Analyzed Logical Plan ==
a: bigint
LocalRelation [a#0L]
== Optimized Logical Plan ==
LocalRelation [a#0L]
== Physical Plan ==
LocalTableScan [a#0L]
```
This is controlled by a new configuration `spark.sql.execution.arrow.localRelationThreshold`, defaulting to 48MB. The default was picked based on the benchmarks below.
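For illustration, a minimal sketch of how the threshold could be exercised from PySpark (assuming a running `spark` session; the values below are examples, not recommendations):
```python
import pandas as pd

spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", True)

# A threshold of 0 bytes forces the RDD-based path (Scan ExistingRDD arrow).
spark.conf.set("spark.sql.execution.arrow.localRelationThreshold", "0")
spark.createDataFrame(pd.DataFrame({'a': [1, 2, 3, 4]})).explain()

# The 48MB default keeps small inputs as a LocalRelation (LocalTableScan).
spark.conf.set("spark.sql.execution.arrow.localRelationThreshold", "48MB")
spark.createDataFrame(pd.DataFrame({'a': [1, 2, 3, 4]})).explain()
```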
In addition, this PR fixes `createDataFrame` to respect the `spark.sql.execution.arrow.maxRecordsPerBatch` configuration when creating Arrow batches. Previously, we sliced the input pandas DataFrame by the default parallelism, which forced users to raise `spark.rpc.message.maxSize` when the input pandas DataFrame was too large. See the benchmarks below.
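Roughly, the slicing change amounts to the following simplified sketch (plain pandas, no Spark; `default_parallelism` and `max_records_per_batch` stand in for `sparkContext.defaultParallelism` and `spark.sql.execution.arrow.maxRecordsPerBatch`):
```python
import pandas as pd

pdf = pd.DataFrame({'a': range(10)})

# Before: one slice per default partition, so a huge DataFrame yields a few huge Arrow batches.
default_parallelism = 4
step_before = -(-len(pdf) // default_parallelism)  # ceiling division, as in the old code

# After: one slice per maxRecordsPerBatch, so each Arrow batch stays bounded.
max_records_per_batch = 3
step_after = max_records_per_batch

slices = [pdf.iloc[start:start + step_after] for start in range(0, len(pdf), step_after)]
print([len(s) for s in slices])  # [3, 3, 3, 1]
```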
### Why are the changes needed?
We have some nice optimizations for `LocalRelation` (e.g., `ConvertToLocalRelation`). For example, the stats are fully known when you use `LocalRelation`, whereas many optimizations cannot be applied with `LogicalRDD`. In some cases (e.g., `executeCollect`), we can also avoid creating `RDD`s entirely.
As for respecting `spark.sql.execution.arrow.maxRecordsPerBatch`: 1. we can avoid forcing users to set `spark.rpc.message.maxSize`, and 2. the configuration is supposed to be respected for every code path that creates Arrow batches, where possible.
### Does this PR introduce _any_ user-facing change?
No, it is an optimization. The number of partitions can be different, but that should be internal.
### How was this patch tested?
- Manually tested.
- Added a unit test.
- I did two benchmark tests with 1 Driver & 4 Workers (i3.xlarge), see below.
#### Benchmark 1 (best cases)
```python
import time
import random
import string
import pandas as pd
names = [random.choice(list(string.ascii_lowercase)) for i in range(1000)]
ages = [random.randint(0, 100) for i in range(1000)]
l = list(zip(names, ages))
d = [{'name': a_name, 'age': an_age} for a_name, an_age in l]
pdf = pd.DataFrame({'name': names, 'age': ages})
spark.range(1).count()  # warm up
start = time.time()
for _ in range(100):
    _ = spark.createDataFrame(pdf)
end = time.time()
print(end - start)
```
Before
10.250491698582968
After
6.004616181055705
#### Benchmark 2 (worst cases)
```bash
curl -O https://eforexcel.com/wp/wp-content/uploads/2020/09/HR2m.zip
unzip HR2m.zip
```
```python
import pandas as pd
pdf = pd.read_csv("HR2m.csv")
pdf23 = pdf.iloc[:int(len(pdf)/32)]
pdf45 = pdf.iloc[:int(len(pdf)/16)]
pdf90 = pdf.iloc[:int(len(pdf)/8)]
pdf175 = pdf.iloc[:int(len(pdf)/4)]
pdf350 = pdf.iloc[:int(len(pdf)/2)]
pdf700 = pdf.iloc[:int(len(pdf))]
pdf2gb = pd.concat([pdf, pdf, pdf])
pdf5gb = pd.concat([pdf2gb, pdf2gb])
spark.createDataFrame(pdf23)._jdf.rdd().count() # explicitly create RDD.
...
```
Before
23MB: 1.02 seconds
45MB: 1.69 seconds
90MB: 2.38 seconds
175MB: 3.19 seconds
350MB: 6.10 seconds
2GB: 43.21 seconds
5GB: X (threw an exception that says to set 'spark.rpc.message.maxSize' higher)
After
23MB: 1.31 seconds (local collection is used)
45MB: 2.47 seconds (local collection is used)
90MB: 1.79 seconds
175MB: 3.22 seconds
350MB: 6.41 seconds
2GB: 47.12 seconds
5GB: 1.29 minutes
**NOTE** that performance varies depending on network stability; the numbers above are from the second run (not an average).
Closes #36683 from HyukjinKwon/SPARK-39301.
Authored-by: Hyukjin Kwon <gu...@apache.org>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/context.py | 10 ++--
python/pyspark/sql/pandas/conversion.py | 12 ++---
python/pyspark/sql/tests/test_arrow.py | 12 +++++
.../org/apache/spark/sql/internal/SQLConf.scala | 14 +++++
.../spark/sql/api/python/PythonSQLUtils.scala | 38 ++++++-------
.../org/apache/spark/sql/api/r/SQLUtils.scala | 12 ++++-
.../sql/execution/arrow/ArrowConverters.scala | 62 +++++++++++++++-------
7 files changed, 109 insertions(+), 51 deletions(-)
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 59b5fa7f3a4..11d75f4f99a 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -679,10 +679,10 @@ class SparkContext:
data: Iterable[T],
serializer: Serializer,
reader_func: Callable,
- createRDDServer: Callable,
+ server_func: Callable,
) -> JavaObject:
"""
- Using py4j to send a large dataset to the jvm is really slow, so we use either a file
+ Using Py4J to send a large dataset to the jvm is slow, so we use either a file
or a socket if we have encryption enabled.
Examples
@@ -693,13 +693,13 @@ class SparkContext:
reader_func : function
A function which takes a filename and reads in the data in the jvm and
returns a JavaRDD. Only used when encryption is disabled.
- createRDDServer : function
- A function which creates a PythonRDDServer in the jvm to
+ server_func : function
+ A function which creates a SocketAuthServer in the JVM to
accept the serialized data, for use when encryption is enabled.
"""
if self._encryption_enabled:
# with encryption, we open a server in java and send the data directly
- server = createRDDServer()
+ server = server_func()
(sock_file, _) = local_connect_and_auth(server.port(), server.secret())
chunked_out = ChunkedStream(sock_file, 8192)
serializer.dump_stream(data, chunked_out)
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index fff0bac5480..119a9bf315c 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -596,7 +596,7 @@ class SparkConversionMixin:
]
# Slice the DataFrame to be batched
- step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up
+ step = self._jconf.arrowMaxRecordsPerBatch()
pdf_slices = (pdf.iloc[start : start + step] for start in range(0, len(pdf), step))
# Create list of Arrow (columns, type) for serializer dump_stream
@@ -613,16 +613,16 @@ class SparkConversionMixin:
@no_type_check
def reader_func(temp_filename):
- return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsparkSession, temp_filename)
+ return self._jvm.PythonSQLUtils.readArrowStreamFromFile(temp_filename)
@no_type_check
- def create_RDD_server():
- return self._jvm.ArrowRDDServer(jsparkSession)
+ def create_iter_server():
+ return self._jvm.ArrowIteratorServer()
# Create Spark DataFrame from Arrow stream file, using one batch per partition
- jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server)
+ jiter = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_iter_server)
assert self._jvm is not None
- jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsparkSession)
+ jdf = self._jvm.PythonSQLUtils.toDataFrame(jiter, schema.json(), jsparkSession)
df = DataFrame(jdf, self)
df._schema = schema
return df
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index b737848b11a..9b1b204542b 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -803,6 +803,18 @@ class EncryptionArrowTests(ArrowTests):
return super(EncryptionArrowTests, cls).conf().set("spark.io.encryption.enabled", "true")
+class RDDBasedArrowTests(ArrowTests):
+ @classmethod
+ def conf(cls):
+ return (
+ super(RDDBasedArrowTests, cls)
+ .conf()
+ .set("spark.sql.execution.arrow.localRelationThreshold", "0")
+ # to test multiple partitions
+ .set("spark.sql.execution.arrow.maxRecordsPerBatch", "2")
+ )
+
+
if __name__ == "__main__":
from pyspark.sql.tests.test_arrow import * # noqa: F401
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b8a752e90ec..4b64d91e56a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2576,6 +2576,18 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val ARROW_LOCAL_RELATION_THRESHOLD =
+ buildConf("spark.sql.execution.arrow.localRelationThreshold")
+ .doc(
+ "When converting Arrow batches to Spark DataFrame, local collections are used in the " +
+ "driver side if the byte size of Arrow batches is smaller than this threshold. " +
+ "Otherwise, the Arrow batches are sent and deserialized to Spark internal rows " +
+ "in the executors.")
+ .version("3.4.0")
+ .bytesConf(ByteUnit.BYTE)
+ .checkValue(_ >= 0, "This value must be equal to or greater than 0.")
+ .createWithDefaultString("48MB")
+
val PYSPARK_JVM_STACKTRACE_ENABLED =
buildConf("spark.sql.pyspark.jvmStacktrace.enabled")
.doc("When true, it shows the JVM stacktrace in the user-facing PySpark exception " +
@@ -4418,6 +4430,8 @@ class SQLConf extends Serializable with Logging {
def arrowPySparkEnabled: Boolean = getConf(ARROW_PYSPARK_EXECUTION_ENABLED)
+ def arrowLocalRelationThreshold: Long = getConf(ARROW_LOCAL_RELATION_THRESHOLD)
+
def arrowPySparkSelfDestructEnabled: Boolean = getConf(ARROW_PYSPARK_SELF_DESTRUCT_ENABLED)
def pysparkJVMStacktraceEnabled: Boolean = getConf(PYSPARK_JVM_STACKTRACE_ENABLED)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index fd689bf502a..a3ba8636233 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -18,15 +18,15 @@
package org.apache.spark.sql.api.python
import java.io.InputStream
+import java.net.Socket
import java.nio.channels.Channels
import java.util.Locale
import net.razorvine.pickle.Pickler
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.api.python.PythonRDDServer
+import org.apache.spark.api.python.DechunkedInputStream
import org.apache.spark.internal.Logging
-import org.apache.spark.rdd.RDD
+import org.apache.spark.security.SocketAuthServer
import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
@@ -70,22 +70,22 @@ private[sql] object PythonSQLUtils extends Logging {
SQLConf.get.timestampType == org.apache.spark.sql.types.TimestampNTZType
/**
- * Python callable function to read a file in Arrow stream format and create a [[RDD]]
- * using each serialized ArrowRecordBatch as a partition.
+ * Python callable function to read a file in Arrow stream format and create an iterator
+ * of serialized ArrowRecordBatches.
*/
- def readArrowStreamFromFile(session: SparkSession, filename: String): JavaRDD[Array[Byte]] = {
- ArrowConverters.readArrowStreamFromFile(session, filename)
+ def readArrowStreamFromFile(filename: String): Iterator[Array[Byte]] = {
+ ArrowConverters.readArrowStreamFromFile(filename).iterator
}
/**
* Python callable function to read a file in Arrow stream format and create a [[DataFrame]]
- * from an RDD.
+ * from the Arrow batch iterator.
*/
def toDataFrame(
- arrowBatchRDD: JavaRDD[Array[Byte]],
+ arrowBatches: Iterator[Array[Byte]],
schemaString: String,
session: SparkSession): DataFrame = {
- ArrowConverters.toDataFrame(arrowBatchRDD, schemaString, session)
+ ArrowConverters.toDataFrame(arrowBatches, schemaString, session)
}
def explainString(queryExecution: QueryExecution, mode: String): String = {
@@ -137,16 +137,16 @@ private[sql] object PythonSQLUtils extends Logging {
}
/**
- * Helper for making a dataframe from arrow data from data sent from python over a socket. This is
+ * Helper for making a dataframe from Arrow data from data sent from python over a socket. This is
* used when encryption is enabled, and we don't want to write data to a file.
*/
-private[sql] class ArrowRDDServer(session: SparkSession) extends PythonRDDServer {
-
- override protected def streamToRDD(input: InputStream): RDD[Array[Byte]] = {
- // Create array to consume iterator so that we can safely close the inputStream
- val batches = ArrowConverters.getBatchesFromStream(Channels.newChannel(input)).toArray
- // Parallelize the record batches to create an RDD
- JavaRDD.fromRDD(session.sparkContext.parallelize(batches, batches.length))
+private[spark] class ArrowIteratorServer
+ extends SocketAuthServer[Iterator[Array[Byte]]]("pyspark-arrow-batches-server") {
+
+ def handleConnection(sock: Socket): Iterator[Array[Byte]] = {
+ val in = sock.getInputStream()
+ val dechunkedInput: InputStream = new DechunkedInputStream(in)
+ // Create array to consume iterator so that we can safely close the file
+ ArrowConverters.getBatchesFromStream(Channels.newChannel(dechunkedInput)).toArray.iterator
}
-
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 7831ddee4f9..f58afcfa05d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -23,6 +23,7 @@ import java.util.{Locale, Map => JMap}
import scala.collection.JavaConverters._
import scala.util.matching.Regex
+import org.apache.spark.TaskContext
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.r.SerDe
import org.apache.spark.broadcast.Broadcast
@@ -230,7 +231,9 @@ private[sql] object SQLUtils extends Logging {
def readArrowStreamFromFile(
sparkSession: SparkSession,
filename: String): JavaRDD[Array[Byte]] = {
- ArrowConverters.readArrowStreamFromFile(sparkSession, filename)
+ // Parallelize the record batches to create an RDD
+ val batches = ArrowConverters.readArrowStreamFromFile(filename)
+ JavaRDD.fromRDD(sparkSession.sparkContext.parallelize(batches, batches.length))
}
/**
@@ -241,6 +244,11 @@ private[sql] object SQLUtils extends Logging {
arrowBatchRDD: JavaRDD[Array[Byte]],
schema: StructType,
sparkSession: SparkSession): DataFrame = {
- ArrowConverters.toDataFrame(arrowBatchRDD, schema.json, sparkSession)
+ val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
+ val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
+ val context = TaskContext.get()
+ ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
+ }
+ sparkSession.internalCreateDataFrame(rdd.setName("arrow"), schema)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 93ff276529d..bded158645c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -29,10 +29,12 @@ import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel
import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, IpcOption, MessageSerializer}
import org.apache.spark.TaskContext
-import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
@@ -68,7 +70,7 @@ private[sql] class ArrowBatchStreamWriter(
}
}
-private[sql] object ArrowConverters {
+private[sql] object ArrowConverters extends Logging {
/**
* Maps Iterator from InternalRow to serialized ArrowRecordBatches. Limit ArrowRecordBatch size
@@ -143,7 +145,7 @@ private[sql] object ArrowConverters {
new Iterator[InternalRow] {
private var rowIter = if (arrowBatchIter.hasNext) nextBatch() else Iterator.empty
- context.addTaskCompletionListener[Unit] { _ =>
+ if (context != null) context.addTaskCompletionListener[Unit] { _ =>
root.close()
allocator.close()
}
@@ -190,32 +192,54 @@ private[sql] object ArrowConverters {
}
/**
- * Create a DataFrame from an RDD of serialized ArrowRecordBatches.
+ * Create a DataFrame from an iterator of serialized ArrowRecordBatches.
*/
- private[sql] def toDataFrame(
- arrowBatchRDD: JavaRDD[Array[Byte]],
+ /**
+ * Create a DataFrame from an iterator of serialized ArrowRecordBatches.
+ */
+ def toDataFrame(
+ arrowBatches: Iterator[Array[Byte]],
schemaString: String,
session: SparkSession): DataFrame = {
val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
- val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
- val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
- val context = TaskContext.get()
- ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
+ val attrs = schema.toAttributes
+ val batchesInDriver = arrowBatches.toArray
+ val shouldUseRDD = session.sessionState.conf
+ .arrowLocalRelationThreshold < batchesInDriver.map(_.length.toLong).sum
+
+ if (shouldUseRDD) {
+ logDebug("Using RDD-based createDataFrame with Arrow optimization.")
+ val timezone = session.sessionState.conf.sessionLocalTimeZone
+ val rdd = session.sparkContext.parallelize(batchesInDriver, batchesInDriver.length)
+ .mapPartitions { batchesInExecutors =>
+ ArrowConverters.fromBatchIterator(
+ batchesInExecutors,
+ schema,
+ timezone,
+ TaskContext.get())
+ }
+ session.internalCreateDataFrame(rdd.setName("arrow"), schema)
+ } else {
+ logDebug("Using LocalRelation in createDataFrame with Arrow optimization.")
+ val data = ArrowConverters.fromBatchIterator(
+ batchesInDriver.toIterator,
+ schema,
+ session.sessionState.conf.sessionLocalTimeZone,
+ TaskContext.get())
+
+ // Project/copy it. Otherwise, the Arrow column vectors will be closed and released out.
+ val proj = UnsafeProjection.create(attrs, attrs)
+ Dataset.ofRows(session, LocalRelation(attrs, data.map(r => proj(r).copy()).toArray))
}
- session.internalCreateDataFrame(rdd.setName("arrow"), schema)
}
/**
- * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches.
+ * Read a file as an Arrow stream and return an array of serialized ArrowRecordBatches.
*/
- private[sql] def readArrowStreamFromFile(
- session: SparkSession,
- filename: String): JavaRDD[Array[Byte]] = {
+ private[sql] def readArrowStreamFromFile(filename: String): Array[Array[Byte]] = {
Utils.tryWithResource(new FileInputStream(filename)) { fileStream =>
// Create array to consume iterator so that we can safely close the file
- val batches = getBatchesFromStream(fileStream.getChannel).toArray
- // Parallelize the record batches to create an RDD
- JavaRDD.fromRDD(session.sparkContext.parallelize(batches, batches.length))
+ getBatchesFromStream(fileStream.getChannel).toArray
}
}