Posted to commits@spark.apache.org by rx...@apache.org on 2014/07/08 09:41:49 UTC

git commit: [SPARK-2391][SQL] Custom take() for LIMIT queries.

Repository: spark
Updated Branches:
  refs/heads/master 3cd5029be -> 5a4063645


[SPARK-2391][SQL] Custom take() for LIMIT queries.

Using Spark's take can result in an entire in-memory partition being shipped in order to retrieve only a single row.
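For illustration only, and not part of the patch: a self-contained Scala sketch of the same idea applied to a plain RDD[Int], matching the loop the diff below adds to executeCollect. Partitions are scanned incrementally and SparkContext.runJob is called with allowLocal = false so a cached partition is never iterated on the driver. The object and method names (TakeNoLocalSketch, takeNoLocal) are invented for this sketch; the runJob signature with allowLocal is the Spark 1.x API used by the patch itself.

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD

object TakeNoLocalSketch {
  /** Collect the first `n` elements without ever running a job locally on the driver. */
  def takeNoLocal(sc: SparkContext, rdd: RDD[Int], n: Int): Array[Int] = {
    if (n == 0) return Array.empty[Int]

    val buf = new ArrayBuffer[Int]
    val totalParts = rdd.partitions.length
    var partsScanned = 0
    while (buf.size < n && partsScanned < totalParts) {
      // The first pass tries a single partition; later passes interpolate how many
      // more partitions are likely needed and overestimate by 50%. For example, with
      // n = 100, 4 partitions scanned and 20 rows buffered so far, the estimate is
      // (1.5 * 100 * 4) / 20 = 30 partitions for the next round.
      var numPartsToTry = 1
      if (partsScanned > 0) {
        numPartsToTry =
          if (buf.isEmpty) totalParts - 1
          else (1.5 * n * partsScanned / buf.size).toInt
      }
      numPartsToTry = math.max(0, numPartsToTry)

      val left = n - buf.size
      val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
      // allowLocal = false is the point: the take runs on the executors that hold the
      // partitions, so at most `left` rows per partition travel back to the driver.
      val res = sc.runJob(rdd, (it: Iterator[Int]) => it.take(left).toArray, p, allowLocal = false)
      res.foreach(buf ++= _.take(n - buf.size))
      partsScanned += numPartsToTry
    }
    buf.toArray
  }

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("take-sketch").setMaster("local[2]"))
    val rdd = sc.parallelize(1 to 1000, 10)
    println(takeNoLocal(sc, rdd, 5).mkString(", "))  // 1, 2, 3, 4, 5
    sc.stop()
  }
}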

Author: Michael Armbrust <mi...@databricks.com>

Closes #1318 from marmbrus/takeLimit and squashes the following commits:

77289a5 [Michael Armbrust] Update scala doc
32f0674 [Michael Armbrust] Custom take implementation for LIMIT queries.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5a406364
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5a406364
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5a406364

Branch: refs/heads/master
Commit: 5a4063645dd7bb4cd8bda890785235729804ab09
Parents: 3cd5029
Author: Michael Armbrust <mi...@databricks.com>
Authored: Tue Jul 8 00:41:46 2014 -0700
Committer: Reynold Xin <rx...@apache.org>
Committed: Tue Jul 8 00:41:46 2014 -0700

----------------------------------------------------------------------
 .../spark/sql/execution/basicOperators.scala    | 51 ++++++++++++++++++--
 1 file changed, 47 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5a406364/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index e8816f0..97abd63 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution
 
+import scala.collection.mutable.ArrayBuffer
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.annotation.DeveloperApi
@@ -83,9 +84,9 @@ case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) ex
  * :: DeveloperApi ::
  * Take the first limit elements. Note that the implementation is different depending on whether
  * this is a terminal operator or not. If it is terminal and is invoked using executeCollect,
- * this operator uses Spark's take method on the Spark driver. If it is not terminal or is
- * invoked using execute, we first take the limit on each partition, and then repartition all the
- * data to a single partition to compute the global limit.
+ * this operator uses something similar to Spark's take method on the Spark driver. If it is not
+ * terminal or is invoked using execute, we first take the limit on each partition, and then
+ * repartition all the data to a single partition to compute the global limit.
  */
 @DeveloperApi
 case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext)
@@ -97,7 +98,49 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext
 
   override def output = child.output
 
-  override def executeCollect() = child.execute().map(_.copy()).take(limit)
+  /**
+   * A custom implementation modeled after the take function on RDDs but which never runs any job
+   * locally.  This is to avoid shipping an entire partition of data in order to retrieve only a few
+   * rows.
+   */
+  override def executeCollect(): Array[Row] = {
+    if (limit == 0) {
+      return new Array[Row](0)
+    }
+
+    val childRDD = child.execute().map(_.copy())
+
+    val buf = new ArrayBuffer[Row]
+    val totalParts = childRDD.partitions.length
+    var partsScanned = 0
+    while (buf.size < limit && partsScanned < totalParts) {
+      // The number of partitions to try in this iteration. It is ok for this number to be
+      // greater than totalParts because we actually cap it at totalParts in runJob.
+      var numPartsToTry = 1
+      if (partsScanned > 0) {
+        // If we didn't find any rows after the first iteration, just try all partitions next.
+        // Otherwise, interpolate the number of partitions we need to try, but overestimate it
+        // by 50%.
+        if (buf.size == 0) {
+          numPartsToTry = totalParts - 1
+        } else {
+          numPartsToTry = (1.5 * limit * partsScanned / buf.size).toInt
+        }
+      }
+      numPartsToTry = math.max(0, numPartsToTry)  // guard against negative num of partitions
+
+      val left = limit - buf.size
+      val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+      val sc = sqlContext.sparkContext
+      val res =
+        sc.runJob(childRDD, (it: Iterator[Row]) => it.take(left).toArray, p, allowLocal = false)
+
+      res.foreach(buf ++= _.take(limit - buf.size))
+      partsScanned += numPartsToTry
+    }
+
+    buf.toArray
+  }
 
   override def execute() = {
     val rdd = child.execute().mapPartitions { iter =>
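The diff excerpt is cut off inside execute(), which implements the non-terminal path the updated scaladoc describes: take the limit on each partition, then bring the surviving rows into a single partition to compute the global limit. Below is a rough standalone sketch of that strategy on a plain RDD, purely for illustration and not the patch code; the object name and the use of coalesce(1, shuffle = true) in place of the operator's own shuffle are this sketch's assumptions.

import org.apache.spark.{SparkConf, SparkContext}

object PartitionThenGlobalLimitSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("limit-sketch").setMaster("local[2]"))
    val limit = 5
    val data = sc.parallelize(1 to 1000, 8)

    val limited = data
      .mapPartitions(_.take(limit))  // local limit: keep at most `limit` rows per partition
      .coalesce(1, shuffle = true)   // move the surviving rows into a single partition
      .mapPartitions(_.take(limit))  // global limit on that single partition

    // Prints some 5 rows; like a SQL LIMIT without ORDER BY, which rows you get is not defined.
    println(limited.collect().mkString(", "))
    sc.stop()
  }
}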