Posted to commits@spark.apache.org by rx...@apache.org on 2013/09/27 00:01:55 UTC

[05/11] git commit: Merge pull request #9 from rxin/limit

Merge pull request #9 from rxin/limit

Smarter take/limit implementation.

(cherry picked from commit 6566a19b38204d754c5e8f821b4276616e90abc6)
Signed-off-by: Reynold Xin <rx...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/240ca935
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/240ca935
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/240ca935

Branch: refs/heads/branch-0.8
Commit: 240ca93539b015cb2556ec80e3cc21afd8ed368f
Parents: b1d0832
Author: Patrick Wendell <pw...@gmail.com>
Authored: Thu Sep 26 08:01:04 2013 -0700
Committer: Reynold Xin <rx...@apache.org>
Committed: Thu Sep 26 13:12:06 2013 -0700

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/rdd/RDD.scala   | 38 ++++++++++++++------
 .../scala/org/apache/spark/rdd/RDDSuite.scala   | 38 ++++++++++++++++++++
 2 files changed, 66 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/240ca935/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 1082cba..1893627 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -753,24 +753,42 @@ abstract class RDD[T: ClassManifest](
   }
 
   /**
-   * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
-   * it will be slow if a lot of partitions are required. In that case, use collect() to get the
-   * whole RDD instead.
+   * Take the first num elements of the RDD. It works by first scanning one partition, and using
+   * the results from that partition to estimate the number of additional partitions needed to
+   * satisfy the limit.
    */
   def take(num: Int): Array[T] = {
     if (num == 0) {
       return new Array[T](0)
     }
+
     val buf = new ArrayBuffer[T]
-    var p = 0
-    while (buf.size < num && p < partitions.size) {
+    val totalParts = this.partitions.length
+    var partsScanned = 0
+    while (buf.size < num && partsScanned < totalParts) {
+      // The number of partitions to try in this iteration. It is ok for this number to be
+      // greater than totalParts because we actually cap it at totalParts in runJob.
+      var numPartsToTry = 1
+      if (partsScanned > 0) {
+        // If we didn't find any rows after the first iteration, just try all partitions next.
+        // Otherwise, interpolate the number of partitions we need to try, but overestimate it
+        // by 50%.
+        if (buf.size == 0) {
+          numPartsToTry = totalParts - 1
+        } else {
+          numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt
+        }
+      }
+      numPartsToTry = math.max(0, numPartsToTry)  // guard against a negative number of partitions
+
       val left = num - buf.size
-      val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, Array(p), true)
-      buf ++= res(0)
-      if (buf.size == num)
-        return buf.toArray
-      p += 1
+      val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+      val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p, allowLocal = true)
+
+      res.foreach(buf ++= _.take(num - buf.size))
+      partsScanned += numPartsToTry
     }
+
     return buf.toArray
   }
 

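The loop above replaces the old one-partition-at-a-time scan. Below is a minimal standalone
sketch of the same scan-and-estimate strategy; TakeSketch, estimatePartsToTry, and rowsPerPart
are illustrative names for this sketch only, not part of the Spark API, and the driver loop
stands in for the real sc.runJob calls.

// Sketch of the take() partition-estimation strategy from the diff above.
// All names here are hypothetical; the loop simulates the driver, not Spark itself.
object TakeSketch {
  // Decide how many partitions to scan in the next round, mirroring the logic
  // in the new take(): 1 partition first, then everything if nothing was found,
  // otherwise an interpolated estimate overshot by 50%.
  def estimatePartsToTry(partsScanned: Int, totalParts: Int,
                         bufSize: Int, num: Int): Int = {
    if (partsScanned == 0) {
      1                              // first round: scan a single partition
    } else if (bufSize == 0) {
      totalParts - 1                 // found no rows yet: try all remaining partitions
    } else {
      // Interpolate from the row density seen so far, then overestimate by 50%.
      math.max(0, (1.5 * num * partsScanned / bufSize).toInt)
    }
  }

  def main(args: Array[String]): Unit = {
    // Simulate taking num = 500 rows when each partition holds ~10 rows.
    val totalParts = 100
    val rowsPerPart = 10
    val num = 500
    var partsScanned = 0
    var bufSize = 0
    while (bufSize < num && partsScanned < totalParts) {
      val numPartsToTry = estimatePartsToTry(partsScanned, totalParts, bufSize, num)
      val scanned = math.min(partsScanned + numPartsToTry, totalParts)
      bufSize = math.min(num, scanned * rowsPerPart)   // rows collected so far, capped at num
      partsScanned = scanned
      println(s"scanned $partsScanned parts, buffered $bufSize rows")
    }
  }
}

With num = 500 and roughly 10 rows per partition, this sketch converges in two rounds
(1 partition, then an estimated 75) instead of launching 50 single-partition jobs, which
is the point of the change.
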
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/240ca935/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index c1df5e1..63adf1c 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -321,6 +321,44 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
   }
 
+  test("take") {
+    var nums = sc.makeRDD(Range(1, 1000), 1)
+    assert(nums.take(0).size === 0)
+    assert(nums.take(1) === Array(1))
+    assert(nums.take(3) === Array(1, 2, 3))
+    assert(nums.take(500) === (1 to 500).toArray)
+    assert(nums.take(501) === (1 to 501).toArray)
+    assert(nums.take(999) === (1 to 999).toArray)
+    assert(nums.take(1000) === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 2)
+    assert(nums.take(0).size === 0)
+    assert(nums.take(1) === Array(1))
+    assert(nums.take(3) === Array(1, 2, 3))
+    assert(nums.take(500) === (1 to 500).toArray)
+    assert(nums.take(501) === (1 to 501).toArray)
+    assert(nums.take(999) === (1 to 999).toArray)
+    assert(nums.take(1000) === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.take(0).size === 0)
+    assert(nums.take(1) === Array(1))
+    assert(nums.take(3) === Array(1, 2, 3))
+    assert(nums.take(500) === (1 to 500).toArray)
+    assert(nums.take(501) === (1 to 501).toArray)
+    assert(nums.take(999) === (1 to 999).toArray)
+    assert(nums.take(1000) === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 1000)
+    assert(nums.take(0).size === 0)
+    assert(nums.take(1) === Array(1))
+    assert(nums.take(3) === Array(1, 2, 3))
+    assert(nums.take(500) === (1 to 500).toArray)
+    assert(nums.take(501) === (1 to 501).toArray)
+    assert(nums.take(999) === (1 to 999).toArray)
+    assert(nums.take(1000) === (1 to 999).toArray)
+  }
+
   test("top with predefined ordering") {
     val nums = Array.range(1, 100000)
     val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2)