You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2014/07/08 09:41:49 UTC
git commit: [SPARK-2391][SQL] Custom take() for LIMIT queries.
Repository: spark
Updated Branches:
refs/heads/master 3cd5029be -> 5a4063645
[SPARK-2391][SQL] Custom take() for LIMIT queries.
Using Spark's take can result in an entire in-memory partition being shipped in order to retrieve a single row.
Author: Michael Armbrust <mi...@databricks.com>
Closes #1318 from marmbrus/takeLimit and squashes the following commits:
77289a5 [Michael Armbrust] Update scala doc
32f0674 [Michael Armbrust] Custom take implementation for LIMIT queries.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5a406364
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5a406364
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5a406364
Branch: refs/heads/master
Commit: 5a4063645dd7bb4cd8bda890785235729804ab09
Parents: 3cd5029
Author: Michael Armbrust <mi...@databricks.com>
Authored: Tue Jul 8 00:41:46 2014 -0700
Committer: Reynold Xin <rx...@apache.org>
Committed: Tue Jul 8 00:41:46 2014 -0700
----------------------------------------------------------------------
.../spark/sql/execution/basicOperators.scala | 51 ++++++++++++++++++--
1 file changed, 47 insertions(+), 4 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/5a406364/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index e8816f0..97abd63 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution
+import scala.collection.mutable.ArrayBuffer
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.annotation.DeveloperApi
@@ -83,9 +84,9 @@ case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) ex
* :: DeveloperApi ::
* Take the first limit elements. Note that the implementation is different depending on whether
* this is a terminal operator or not. If it is terminal and is invoked using executeCollect,
- * this operator uses Spark's take method on the Spark driver. If it is not terminal or is
- * invoked using execute, we first take the limit on each partition, and then repartition all the
- * data to a single partition to compute the global limit.
+ * this operator uses something similar to Spark's take method on the Spark driver. If it is not
+ * terminal or is invoked using execute, we first take the limit on each partition, and then
+ * repartition all the data to a single partition to compute the global limit.
*/
@DeveloperApi
case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext)
@@ -97,7 +98,49 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext
override def output = child.output
- override def executeCollect() = child.execute().map(_.copy()).take(limit)
+  /**
+   * A custom implementation modeled after the take function on RDDs but which never runs any job
+   * locally. This is to avoid shipping an entire partition of data in order to retrieve only a few
+   * rows.
+   *
+   * Partitions are scanned in rounds. The first round scans a single partition. Each later round
+   * scans at most `scaleUpFactor` times as many partitions as have already been scanned: if no
+   * rows have been found yet it scans exactly that many; otherwise it interpolates how many more
+   * partitions should be needed (overestimating by 50%). Every task returns at most the number of
+   * rows still missing, so bounding the partitions per round also bounds how much data one round
+   * ships back to the caller of runJob.
+   *
+   * @return at most `limit` rows in partition order; an empty array if `limit` is zero or negative.
+   */
+  override def executeCollect(): Array[Row] = {
+    if (limit <= 0) {
+      Array.empty[Row]
+    } else {
+      // Every row is copied before being collected (presumably because the child may reuse
+      // mutable row objects -- same as the previous implementation).
+      val childRDD = child.execute().map(_.copy())
+      val sc = sqlContext.sparkContext
+      // Upper bound on how fast the number of partitions scanned per round may grow.
+      val scaleUpFactor = 4
+
+      val buf = new ArrayBuffer[Row]
+      val totalParts = childRDD.partitions.length
+      var partsScanned = 0
+      while (buf.size < limit && partsScanned < totalParts) {
+        // Desired number of additional partitions for this round. Computed in Long so that the
+        // products below cannot overflow Int for very large limits or partition counts.
+        val estimate: Long =
+          if (partsScanned == 0) {
+            1L
+          } else if (buf.isEmpty) {
+            // Nothing found yet: grow geometrically instead of scanning every remaining partition
+            // at once, since each task may return up to `left` rows.
+            partsScanned.toLong * scaleUpFactor
+          } else {
+            // Interpolate the total number of partitions needed (overestimated by 50%) and
+            // subtract the ones already scanned; the scale-up cap still applies.
+            val interpolated = (1.5 * limit * partsScanned / buf.size).toLong - partsScanned
+            math.min(interpolated, partsScanned.toLong * scaleUpFactor)
+          }
+        // Always make progress (at least one partition) and never go past the last partition,
+        // which also keeps `partsScanned + numPartsToTry` within Int range.
+        val remaining = totalParts - partsScanned
+        val numPartsToTry = math.max(1L, math.min(estimate, remaining.toLong)).toInt
+
+        val left = limit - buf.size
+        val p = partsScanned until partsScanned + numPartsToTry
+        val res =
+          sc.runJob(childRDD, (it: Iterator[Row]) => it.take(left).toArray, p, allowLocal = false)
+
+        res.foreach(buf ++= _.take(limit - buf.size))
+        partsScanned += numPartsToTry
+      }
+
+      buf.toArray
+    }
+  }
override def execute() = {
val rdd = child.execute().mapPartitions { iter =>