Posted to commits@spark.apache.org by da...@apache.org on 2016/09/14 17:09:58 UTC

spark git commit: [SPARK-17514] df.take(1) and df.limit(1).collect() should perform the same in Python

Repository: spark
Updated Branches:
  refs/heads/master 52738d4e0 -> 6d06ff6f7


[SPARK-17514] df.take(1) and df.limit(1).collect() should perform the same in Python

## What changes were proposed in this pull request?

In PySpark, `df.take(1)` runs a single-stage job which computes only one partition of the DataFrame, while `df.limit(1).collect()` computes all partitions and runs a two-stage job. This difference in performance is confusing.
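
For illustration, here is a minimal sketch (assuming a running `SparkSession` named `spark`; the job-group names are just bookkeeping labels, not part of any API) that makes the difference observable through the status tracker:

```python
sc = spark.sparkContext
df = spark.range(1, 1000, numPartitions=10)

def count_jobs_stages_tasks(group, f):
    # Tag the work with a job group so its jobs can be looked up afterwards.
    sc.setJobGroup(group, description="")
    f()
    tracker = sc.statusTracker()
    jobs = tracker.getJobIdsForGroup(group)
    stages = [s for j in jobs for s in tracker.getJobInfo(j).stageIds]
    tasks = sum(tracker.getStageInfo(s).numTasks for s in stages)
    return len(jobs), len(stages), tasks

# take(1) runs a single job with a single stage and a single task.
print(count_jobs_stages_tasks("take", lambda: df.take(1)))
# Before this patch, limit(1).collect() ran a two-stage job over all partitions.
print(count_jobs_stages_tasks("limit_collect", lambda: df.limit(1).collect()))
```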

The reason why `limit(1).collect()` is so much slower is that `collect()` internally maps to `df.rdd.<some-pyspark-conversions>.toLocalIterator`, which causes Spark SQL to build a query where a global limit appears in the middle of the plan. That query ends up being executed inefficiently because limits in the middle of a plan are now implemented by repartitioning to a single task rather than by running a `take()` job on the driver (this change was made in #7334, a patch which was a prerequisite to allowing partition-local limits to be pushed beneath unions, etc.).
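
The extra shuffle is easiest to see by comparing how the same limit is planned on the SQL path versus the RDD-conversion path. The following is only a sketch of how to inspect it (the exact operator names in the plan output vary across Spark versions):

```python
df = spark.range(1, 1000, numPartitions=10)
limited = df.limit(1)

# On the SQL path the limit can be planned as a CollectLimit at the top of the
# query and executed like a take() job on the driver.
limited.explain()

# Converting to an RDD first (roughly what collect() did before this patch)
# pins the global limit in the middle of the physical plan, so executing it
# shuffles everything down to a single partition before collecting.
rows = limited.rdd.collect()
```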

In order to fix this performance problem, I think we should generalize the fix from SPARK-10731 / #8876 so that `DataFrame.collect()` also delegates to the Scala implementation and shares the same performance properties. This patch therefore has `DataFrame.take(num)` delegate to `limit(num).collect()` on the Python side and changes the Scala backing of `collect()` (`Dataset.collectToPython()`) to gather results on the driver via `executeCollect()` before passing them to Python, allowing these queries to be planned using Spark's `CollectLimit` optimizations.
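
With the change in place, both entry points end up on the same Scala code path, so a quick sanity check along these lines (a sketch, reusing the `df` from above) should show identical behavior:

```python
# In effect, after this patch:
#   DataFrame.take(n) == DataFrame.limit(n).collect()
# and both are served by Dataset.collectToPython(), which collects rows on the
# driver with executeCollect() and streams the pickled results back to Python.
first_via_take = df.take(5)
first_via_limit = df.limit(5).collect()
assert len(first_via_take) == len(first_via_limit) == 5
```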

## How was this patch tested?

Added a regression test to `python/pyspark/sql/tests.py` which asserts that both queries run the expected number of jobs, stages, and tasks.

Author: Josh Rosen <jo...@databricks.com>

Closes #15068 from JoshRosen/pyspark-collect-limit.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6d06ff6f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6d06ff6f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6d06ff6f

Branch: refs/heads/master
Commit: 6d06ff6f7e2dd72ba8fe96cd875e83eda6ebb2a9
Parents: 52738d4
Author: Josh Rosen <jo...@databricks.com>
Authored: Wed Sep 14 10:10:01 2016 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Wed Sep 14 10:10:01 2016 -0700

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                   |  5 +----
 python/pyspark/sql/tests.py                       | 18 ++++++++++++++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala |  8 ++++++--
 .../sql/execution/python/EvaluatePython.scala     | 13 +------------
 4 files changed, 26 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6d06ff6f/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index e5eac91..0f7d8fb 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -357,10 +357,7 @@ class DataFrame(object):
         >>> df.take(2)
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         """
-        with SCCallSiteSync(self._sc) as css:
-            port = self._sc._jvm.org.apache.spark.sql.execution.python.EvaluatePython.takeAndServe(
-                self._jdf, num)
-        return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
+        return self.limit(num).collect()
 
     @since(1.3)
     def foreach(self, f):

http://git-wip-us.apache.org/repos/asf/spark/blob/6d06ff6f/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 769e454..1be0b72 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1862,6 +1862,24 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
             sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r),
             ["1", "2", "2", "2"])
 
+    def test_limit_and_take(self):
+        df = self.spark.range(1, 1000, numPartitions=10)
+
+        def assert_runs_only_one_job_stage_and_task(job_group_name, f):
+            tracker = self.sc.statusTracker()
+            self.sc.setJobGroup(job_group_name, description="")
+            f()
+            jobs = tracker.getJobIdsForGroup(job_group_name)
+            self.assertEqual(1, len(jobs))
+            stages = tracker.getJobInfo(jobs[0]).stageIds
+            self.assertEqual(1, len(stages))
+            self.assertEqual(1, tracker.getStageInfo(stages[0]).numTasks)
+
+        # Regression test for SPARK-10731: take should delegate to Scala implementation
+        assert_runs_only_one_job_stage_and_task("take", lambda: df.take(1))
+        # Regression test for SPARK-17514: limit(n).collect() should perform the same as take(n)
+        assert_runs_only_one_job_stage_and_task("collect_limit", lambda: df.limit(1).collect())
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests import *

http://git-wip-us.apache.org/repos/asf/spark/blob/6d06ff6f/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 3b3cb82..9cfbdff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -29,7 +29,7 @@ import org.apache.commons.lang3.StringUtils
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.function._
-import org.apache.spark.api.python.PythonRDD
+import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst._
@@ -2567,8 +2567,12 @@ class Dataset[T] private[sql](
   }
 
   private[sql] def collectToPython(): Int = {
+    EvaluatePython.registerPicklers()
     withNewExecutionId {
-      PythonRDD.collectAndServe(javaToPython.rdd)
+      val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)
+      val iter = new SerDeUtil.AutoBatchedPickler(
+        queryExecution.executedPlan.executeCollect().iterator.map(toJava))
+      PythonRDD.serveIterator(iter, "serve-DataFrame")
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6d06ff6f/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index cf68ed4..724025b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -24,9 +24,8 @@ import scala.collection.JavaConverters._
 
 import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler}
 
-import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
+import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
@@ -34,16 +33,6 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 object EvaluatePython {
-  def takeAndServe(df: DataFrame, n: Int): Int = {
-    registerPicklers()
-    df.withNewExecutionId {
-      val iter = new SerDeUtil.AutoBatchedPickler(
-        df.queryExecution.executedPlan.executeTake(n).iterator.map { row =>
-          EvaluatePython.toJava(row, df.schema)
-        })
-      PythonRDD.serveIterator(iter, s"serve-DataFrame")
-    }
-  }
 
   def needConversionInPython(dt: DataType): Boolean = dt match {
     case DateType | TimestampType => true

