Posted to commits@spark.apache.org by ho...@apache.org on 2019/09/20 17:00:23 UTC

[spark] branch master updated: [SPARK-27659][PYTHON] Allow PySpark to prefetch during toLocalIterator

This is an automated email from the ASF dual-hosted git repository.

holden pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 42050c3  [SPARK-27659][PYTHON] Allow PySpark to prefetch during toLocalIterator
42050c3 is described below

commit 42050c3f4f21adaa14808e474a0db69f62671935
Author: Holden Karau <hk...@apple.com>
AuthorDate: Fri Sep 20 09:59:31 2019 -0700

    [SPARK-27659][PYTHON] Allow PySpark to prefetch during toLocalIterator
    
    ### What changes were proposed in this pull request?
    
    This PR allows the Python toLocalIterator to prefetch the next partition while the current one is being collected. The PR also adds a demo micro benchmark in the examples directory; we may wish to keep this or not.
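    
    For illustration, the intended usage looks roughly like the sketch below. This is not part of the change itself; the SparkSession setup and the trivial loop bodies are assumptions.
    
        from pyspark.sql import SparkSession
    
        spark = SparkSession.builder.getOrCreate()
        df = spark.range(1000, numPartitions=10)
    
        # Default behaviour: each partition is collected only when the iterator reaches it.
        for row in df.toLocalIterator():
            pass
    
        # Opt in to prefetching: the next partition is submitted while this one is consumed.
        for row in df.toLocalIterator(prefetchPartitions=True):
            pass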
    
    ### Why are the changes needed?
    
    In https://issues.apache.org/jira/browse/SPARK-23961 / 5e79ae3b40b76e3473288830ab958fc4834dcb33 we changed PySpark to only pull one partition at a time. This is memory efficient, but if partitions take time to compute, we spend that time blocked on each collect before the next partition's job even starts.
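    
    Conceptually, the prefetch overlaps computing partition i+1 with consuming partition i. A rough Python sketch of the idea (illustration only; the actual change is in PythonRDD.scala using submitJob and a buffered iterator, and collect_partition here is a hypothetical callback):
    
        from concurrent.futures import ThreadPoolExecutor
    
        def iterate_with_prefetch(collect_partition, num_partitions):
            # Yield each partition's rows while the next partition is already
            # being computed in a background thread.
            with ThreadPoolExecutor(max_workers=1) as pool:
                future = pool.submit(collect_partition, 0) if num_partitions else None
                for i in range(num_partitions):
                    rows = future.result()  # blocks only if the prefetch has not finished yet
                    if i + 1 < num_partitions:
                        future = pool.submit(collect_partition, i + 1)  # kick off the next collect
                    yield rows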
    
    ### Does this PR introduce any user-facing change?
    
    A new parameter, `prefetchPartitions`, is added to toLocalIterator on both RDD and DataFrame.
    
    ### How was this patch tested?
    
    A new unit test inside `test_rdd.py` checks the times at which the elements are evaluated. Another test, verifying that the results remain the same, is added to `test_dataframe.py`.
    
    I also ran a micro benchmark, `prefetch.py` in the examples directory, which shows an improvement of ~40% in this specific use case; a sketch of what such a benchmark could look like follows the timing output below.
    
    >
    > 19/08/16 17:11:36 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
    > Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
    > Setting default log level to "WARN".
    > To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
    > Running timers:
    >
    > [Stage 32:>                                                         (0 + 1) / 1]
    > Results:
    >
    > Prefetch time:
    >
    > 100.228110831
    >
    >
    > Regular time:
    >
    > 188.341721614
    >
    >
    >
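    
    For reference, a rough sketch of what such a micro benchmark could look like (prefetch.py itself is not included in this diff; `sc` as an existing SparkContext and the artificial sleeps are assumptions):
    
        import time
    
        # Partitions that each take about a second to compute.
        slow = sc.parallelize(range(16), 16).map(lambda x: (time.sleep(1), x)[1])
    
        def time_iteration(rdd, prefetch):
            start = time.time()
            for _ in rdd.toLocalIterator(prefetchPartitions=prefetch):
                time.sleep(1)  # simulate per-element driver work while the next partition computes
            return time.time() - start
    
        print("Prefetch time:", time_iteration(slow, True))
        print("Regular time:", time_iteration(slow, False))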
    
    Closes #25515 from holdenk/SPARK-27659-allow-pyspark-tolocalitr-to-prefetch.
    
    Authored-by: Holden Karau <hk...@apple.com>
    Signed-off-by: Holden Karau <hk...@apple.com>
---
 .../org/apache/spark/api/python/PythonRDD.scala    | 21 +++++++++++++++++----
 python/pyspark/rdd.py                              | 10 ++++++++--
 python/pyspark/sql/dataframe.py                    |  8 ++++++--
 python/pyspark/sql/tests/test_dataframe.py         |  6 ++++++
 python/pyspark/tests/test_rdd.py                   | 22 ++++++++++++++++++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  4 ++--
 6 files changed, 61 insertions(+), 10 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 4d76ff7..7cbfb71 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -24,6 +24,7 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.concurrent.duration.Duration
 import scala.reflect.ClassTag
 
 import org.apache.hadoop.conf.Configuration
@@ -179,15 +180,22 @@ private[spark] object PythonRDD extends Logging {
    *         data collected from this job, the secret for authentication, and a socket auth
    *         server object that can be used to join the JVM serving thread in Python.
    */
-  def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
+  def toLocalIteratorAndServe[T](rdd: RDD[T], prefetchPartitions: Boolean = false): Array[Any] = {
     val handleFunc = (sock: Socket) => {
       val out = new DataOutputStream(sock.getOutputStream)
       val in = new DataInputStream(sock.getInputStream)
       Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
         // Collects a partition on each iteration
         val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
-          rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
+          var result: Array[Any] = null
+          rdd.sparkContext.submitJob(
+            rdd,
+            (iter: Iterator[Any]) => iter.toArray,
+            Seq(i), // The partition we are evaluating
+            (_, res: Array[Any]) => result = res,
+            result)
         }
+        val prefetchIter = collectPartitionIter.buffered
 
         // Write data until iteration is complete, client stops iteration, or error occurs
         var complete = false
@@ -196,10 +204,15 @@ private[spark] object PythonRDD extends Logging {
           // Read request for data, value of zero will stop iteration or non-zero to continue
           if (in.readInt() == 0) {
             complete = true
-          } else if (collectPartitionIter.hasNext) {
+          } else if (prefetchIter.hasNext) {
 
             // Client requested more data, attempt to collect the next partition
-            val partitionArray = collectPartitionIter.next()
+            val partitionFuture = prefetchIter.next()
+            // Cause the next job to be submitted if prefetchPartitions is enabled.
+            if (prefetchPartitions) {
+              prefetchIter.headOption
+            }
+            val partitionArray = ThreadUtils.awaitResult(partitionFuture, Duration.Inf)
 
             // Send response there is a partition to read
             out.writeInt(1)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index be0244b7..1edffaa 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2437,17 +2437,23 @@ class RDD(object):
         hashRDD = self.map(lambda x: portable_hash(x) & 0xFFFFFFFF)
         return hashRDD._to_java_object_rdd().countApproxDistinct(relativeSD)
 
-    def toLocalIterator(self):
+    def toLocalIterator(self, prefetchPartitions=False):
         """
         Return an iterator that contains all of the elements in this RDD.
         The iterator will consume as much memory as the largest partition in this RDD.
+        With prefetch it may consume up to the memory of the 2 largest partitions.
+
+        :param prefetchPartitions: If Spark should pre-fetch the next partition
+                                   before it is needed.
 
         >>> rdd = sc.parallelize(range(10))
         >>> [x for x in rdd.toLocalIterator()]
         [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
         """
         with SCCallSiteSync(self.context) as css:
-            sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
+            sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(
+                self._jrdd.rdd(),
+                prefetchPartitions)
         return _local_iterator_from_socket(sock_info, self._jrdd_deserializer)
 
     def barrier(self):
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 3984712..03b37fa 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -520,16 +520,20 @@ class DataFrame(object):
 
     @ignore_unicode_prefix
     @since(2.0)
-    def toLocalIterator(self):
+    def toLocalIterator(self, prefetchPartitions=False):
         """
         Returns an iterator that contains all of the rows in this :class:`DataFrame`.
         The iterator will consume as much memory as the largest partition in this DataFrame.
+        With prefetch it may consume up to the memory of the 2 largest partitions.
+
+        :param prefetchPartitions: If Spark should pre-fetch the next partition
+                                   before it is needed.
 
         >>> list(df.toLocalIterator())
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         """
         with SCCallSiteSync(self._sc) as css:
-            sock_info = self._jdf.toPythonIterator()
+            sock_info = self._jdf.toPythonIterator(prefetchPartitions)
         return _local_iterator_from_socket(sock_info, BatchedSerializer(PickleSerializer()))
 
     @ignore_unicode_prefix
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index bc4ee88..90a5415 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -690,6 +690,12 @@ class DataFrameTests(ReusedSQLTestCase):
         expected = df.collect()
         self.assertEqual(expected, list(it))
 
+    def test_to_local_iterator_prefetch(self):
+        df = self.spark.range(8, numPartitions=4)
+        expected = df.collect()
+        it = df.toLocalIterator(prefetchPartitions=True)
+        self.assertEqual(expected, list(it))
+
     def test_to_local_iterator_not_fully_consumed(self):
         # SPARK-23961: toLocalIterator throws exception when not fully consumed
         # Create a DataFrame large enough so that write to socket will eventually block
diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py
index bff0803..e7a7971 100644
--- a/python/pyspark/tests/test_rdd.py
+++ b/python/pyspark/tests/test_rdd.py
@@ -14,11 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from datetime import datetime, timedelta
 import hashlib
 import os
 import random
 import sys
 import tempfile
+import time
 from glob import glob
 
 from py4j.protocol import Py4JJavaError
@@ -68,6 +70,26 @@ class RDDTests(ReusedPySparkTestCase):
         it2 = rdd2.toLocalIterator()
         self.assertEqual([1, 2, 3], sorted(it2))
 
+    def test_to_localiterator_prefetch(self):
+        # Test that we fetch the next partition in parallel.
+        # Each element records the current time: we read the first element,
+        # wait, and then read the second element.
+        # If the partitions were computed serially, the two timestamps would
+        # differ by at least the wait; because the second partition is
+        # prefetched in parallel, the timestamps are "close enough" to the same.
+        rdd = self.sc.parallelize(range(2), 2)
+        times1 = rdd.map(lambda x: datetime.now())
+        times2 = rdd.map(lambda x: datetime.now())
+        times_iter_prefetch = times1.toLocalIterator(prefetchPartitions=True)
+        times_iter = times2.toLocalIterator(prefetchPartitions=False)
+        times_prefetch_head = next(times_iter_prefetch)
+        times_head = next(times_iter)
+        time.sleep(2)
+        times_next = next(times_iter)
+        times_prefetch_next = next(times_iter_prefetch)
+        self.assertTrue(times_next - times_head >= timedelta(seconds=2))
+        self.assertTrue(times_prefetch_next - times_prefetch_head < timedelta(seconds=1))
+
     def test_save_as_textfile_with_unicode(self):
         # Regression test for SPARK-970
         x = u"\u00A1Hola, mundo!"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a2f5f03..9a2d800 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3356,9 +3356,9 @@ class Dataset[T] private[sql](
     }
   }
 
-  private[sql] def toPythonIterator(): Array[Any] = {
+  private[sql] def toPythonIterator(prefetchPartitions: Boolean = false): Array[Any] = {
     withNewExecutionId {
-      PythonRDD.toLocalIteratorAndServe(javaToPython.rdd)
+      PythonRDD.toLocalIteratorAndServe(javaToPython.rdd, prefetchPartitions)
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org