You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ho...@apache.org on 2019/09/20 17:00:23 UTC
[spark] branch master updated: [SPARK-27659][PYTHON] Allow PySpark to prefetch during toLocalIterator
This is an automated email from the ASF dual-hosted git repository.
holden pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 42050c3 [SPARK-27659][PYTHON] Allow PySpark to prefetch during toLocalIterator
42050c3 is described below
commit 42050c3f4f21adaa14808e474a0db69f62671935
Author: Holden Karau <hk...@apple.com>
AuthorDate: Fri Sep 20 09:59:31 2019 -0700
[SPARK-27659][PYTHON] Allow PySpark to prefetch during toLocalIterator
### What changes were proposed in this pull request?
This PR allows Python's toLocalIterator to prefetch the next partition while the current partition is being collected. The PR also proposed a demo micro benchmark in the examples directory, which we may or may not wish to keep (it is not among the files changed by this commit).
### Why are the changes needed?
In https://issues.apache.org/jira/browse/SPARK-23961 / 5e79ae3b40b76e3473288830ab958fc4834dcb33 we changed PySpark to pull only one partition at a time. This is memory efficient, but if partitions are slow to compute, more time is spent blocking while waiting for each partition.
### Does this PR introduce any user-facing change?
A new optional parameter, `prefetchPartitions` (default `False`), is added to `toLocalIterator` on both `RDD` and `DataFrame`.
### How was this patch tested?
A new unit test in `test_rdd.py` checks the times at which the elements are evaluated. Another test, which checks that the results remain the same, is added to `test_dataframe.py`.
I also ran a micro benchmark, `prefetch.py` in the examples directory. In this specific use case it shows prefetching reducing the run time by roughly 47% (100.23 vs. 188.34 in the output below).
>
> 19/08/16 17:11:36 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
> Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
> Setting default log level to "WARN".
> To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
> Running timers:
>
> [Stage 32:> (0 + 1) / 1]
> Results:
>
> Prefetch time:
>
> 100.228110831
>
>
> Regular time:
>
> 188.341721614
>
>
>
Closes #25515 from holdenk/SPARK-27659-allow-pyspark-tolocalitr-to-prefetch.
Authored-by: Holden Karau <hk...@apple.com>
Signed-off-by: Holden Karau <hk...@apple.com>
---
.../org/apache/spark/api/python/PythonRDD.scala | 21 +++++++++++++++++----
python/pyspark/rdd.py | 10 ++++++++--
python/pyspark/sql/dataframe.py | 8 ++++++--
python/pyspark/sql/tests/test_dataframe.py | 6 ++++++
python/pyspark/tests/test_rdd.py | 22 ++++++++++++++++++++++
.../main/scala/org/apache/spark/sql/Dataset.scala | 4 ++--
6 files changed, 61 insertions(+), 10 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 4d76ff7..7cbfb71 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -24,6 +24,7 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
import scala.collection.JavaConverters._
import scala.collection.mutable
+import scala.concurrent.duration.Duration
import scala.reflect.ClassTag
import org.apache.hadoop.conf.Configuration
@@ -179,15 +180,22 @@ private[spark] object PythonRDD extends Logging {
* data collected from this job, the secret for authentication, and a socket auth
* server object that can be used to join the JVM serving thread in Python.
*/
- def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
+ def toLocalIteratorAndServe[T](rdd: RDD[T], prefetchPartitions: Boolean = false): Array[Any] = {
val handleFunc = (sock: Socket) => {
val out = new DataOutputStream(sock.getOutputStream)
val in = new DataInputStream(sock.getInputStream)
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
// Collects a partition on each iteration
val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
- rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
+ var result: Array[Any] = null
+ rdd.sparkContext.submitJob(
+ rdd,
+ (iter: Iterator[Any]) => iter.toArray,
+ Seq(i), // The partition we are evaluating
+ (_, res: Array[Any]) => result = res,
+ result)
}
+ val prefetchIter = collectPartitionIter.buffered
// Write data until iteration is complete, client stops iteration, or error occurs
var complete = false
@@ -196,10 +204,15 @@ private[spark] object PythonRDD extends Logging {
// Read request for data, value of zero will stop iteration or non-zero to continue
if (in.readInt() == 0) {
complete = true
- } else if (collectPartitionIter.hasNext) {
+ } else if (prefetchIter.hasNext) {
// Client requested more data, attempt to collect the next partition
- val partitionArray = collectPartitionIter.next()
+ val partitionFuture = prefetchIter.next()
+ // Cause the next job to be submitted if prefetchPartitions is enabled.
+ if (prefetchPartitions) {
+ prefetchIter.headOption
+ }
+ val partitionArray = ThreadUtils.awaitResult(partitionFuture, Duration.Inf)
// Send response there is a partition to read
out.writeInt(1)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index be0244b7..1edffaa 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2437,17 +2437,23 @@ class RDD(object):
hashRDD = self.map(lambda x: portable_hash(x) & 0xFFFFFFFF)
return hashRDD._to_java_object_rdd().countApproxDistinct(relativeSD)
- def toLocalIterator(self):
+ def toLocalIterator(self, prefetchPartitions=False):
"""
Return an iterator that contains all of the elements in this RDD.
The iterator will consume as much memory as the largest partition in this RDD.
+ With prefetch it may consume up to the memory of the 2 largest partitions.
+
+ :param prefetchPartitions: If Spark should pre-fetch the next partition
+ before it is needed.
>>> rdd = sc.parallelize(range(10))
>>> [x for x in rdd.toLocalIterator()]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
"""
with SCCallSiteSync(self.context) as css:
- sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
+ sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(
+ self._jrdd.rdd(),
+ prefetchPartitions)
return _local_iterator_from_socket(sock_info, self._jrdd_deserializer)
def barrier(self):
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 3984712..03b37fa 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -520,16 +520,20 @@ class DataFrame(object):
@ignore_unicode_prefix
@since(2.0)
- def toLocalIterator(self):
+ def toLocalIterator(self, prefetchPartitions=False):
"""
Returns an iterator that contains all of the rows in this :class:`DataFrame`.
The iterator will consume as much memory as the largest partition in this DataFrame.
+ With prefetch it may consume up to the memory of the 2 largest partitions.
+
+ :param prefetchPartitions: If Spark should pre-fetch the next partition
+ before it is needed.
>>> list(df.toLocalIterator())
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
with SCCallSiteSync(self._sc) as css:
- sock_info = self._jdf.toPythonIterator()
+ sock_info = self._jdf.toPythonIterator(prefetchPartitions)
return _local_iterator_from_socket(sock_info, BatchedSerializer(PickleSerializer()))
@ignore_unicode_prefix
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index bc4ee88..90a5415 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -690,6 +690,12 @@ class DataFrameTests(ReusedSQLTestCase):
expected = df.collect()
self.assertEqual(expected, list(it))
+ def test_to_local_iterator_prefetch(self):
+ df = self.spark.range(8, numPartitions=4)
+ expected = df.collect()
+ it = df.toLocalIterator(prefetchPartitions=True)
+ self.assertEqual(expected, list(it))
+
def test_to_local_iterator_not_fully_consumed(self):
# SPARK-23961: toLocalIterator throws exception when not fully consumed
# Create a DataFrame large enough so that write to socket will eventually block
diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py
index bff0803..e7a7971 100644
--- a/python/pyspark/tests/test_rdd.py
+++ b/python/pyspark/tests/test_rdd.py
@@ -14,11 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+from datetime import datetime, timedelta
import hashlib
import os
import random
import sys
import tempfile
+import time
from glob import glob
from py4j.protocol import Py4JJavaError
@@ -68,6 +70,26 @@ class RDDTests(ReusedPySparkTestCase):
it2 = rdd2.toLocalIterator()
self.assertEqual([1, 2, 3], sorted(it2))
+ def test_to_localiterator_prefetch(self):
+ # Test that we fetch the next partition in parallel
+ # We do this by returning the current time and:
+ # reading the first elem, waiting, and reading the second elem
+ # If not in parallel then these would be at different times
+ # But since they are being computed in parallel we see the time
+ # is "close enough" to the same.
+ rdd = self.sc.parallelize(range(2), 2)
+ times1 = rdd.map(lambda x: datetime.now())
+ times2 = rdd.map(lambda x: datetime.now())
+ times_iter_prefetch = times1.toLocalIterator(prefetchPartitions=True)
+ times_iter = times2.toLocalIterator(prefetchPartitions=False)
+ times_prefetch_head = next(times_iter_prefetch)
+ times_head = next(times_iter)
+ time.sleep(2)
+ times_next = next(times_iter)
+ times_prefetch_next = next(times_iter_prefetch)
+ self.assertTrue(times_next - times_head >= timedelta(seconds=2))
+ self.assertTrue(times_prefetch_next - times_prefetch_head < timedelta(seconds=1))
+
def test_save_as_textfile_with_unicode(self):
# Regression test for SPARK-970
x = u"\u00A1Hola, mundo!"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a2f5f03..9a2d800 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3356,9 +3356,9 @@ class Dataset[T] private[sql](
}
}
- private[sql] def toPythonIterator(): Array[Any] = {
+ private[sql] def toPythonIterator(prefetchPartitions: Boolean = false): Array[Any] = {
withNewExecutionId {
- PythonRDD.toLocalIteratorAndServe(javaToPython.rdd)
+ PythonRDD.toLocalIteratorAndServe(javaToPython.rdd, prefetchPartitions)
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org