You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2014/06/13 04:44:42 UTC

git commit: SPARK-1939 Refactor takeSample method in RDD to use ScaSRS

Repository: spark
Updated Branches:
  refs/heads/master 0154587ab -> 1de1d703b


SPARK-1939 Refactor takeSample method in RDD to use ScaSRS

Modified the takeSample method in RDD to use the ScaSRS sampling technique to improve performance. Added a private method that computes sampling rate > sample_size/total to ensure sufficient sample size with success rate >= 0.9999. Added a unit test for the private method to validate choice of sampling rate.

Author: Doris Xin <do...@gmail.com>
Author: dorx <do...@gmail.com>
Author: Xiangrui Meng <me...@databricks.com>

Closes #916 from dorx/takeSample and squashes the following commits:

5b061ae [Doris Xin] merge master
444e750 [Doris Xin] edge cases
3de882b [dorx] Merge pull request #2 from mengxr/SPARK-1939
82dde31 [Xiangrui Meng] update pyspark's takeSample
48d954d [Doris Xin] remove unused imports from RDDSuite
fb1452f [Doris Xin] allowing num to be greater than count in all cases
1481b01 [Doris Xin] washing test tubes and making coffee
dc699f3 [Doris Xin] give back imports removed by accident in rdd.py
64e445b [Doris Xin] logwarnning as soon as it enters the while loop
55518ed [Doris Xin] added TODO for logging in rdd.py
eff89e2 [Doris Xin] addressed reviewer comments.
ecab508 [Doris Xin] "fixed checkstyle violation
0a9b3e3 [Doris Xin] "reviewer comment addressed"
f80f270 [Doris Xin] Merge branch 'master' into takeSample
ae3ad04 [Doris Xin] fixed edge cases to prevent overflow
065ebcd [Doris Xin] Merge branch 'master' into takeSample
9bdd36e [Doris Xin] Check sample size and move computeFraction
e3fd6a6 [Doris Xin] Merge branch 'master' into takeSample
7cab53a [Doris Xin] fixed import bug in rdd.py
ffea61a [Doris Xin] SPARK-1939: Refactor takeSample method in RDD
1441977 [Doris Xin] SPARK-1939 Refactor takeSample method in RDD to use ScaSRS


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1de1d703
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1de1d703
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1de1d703

Branch: refs/heads/master
Commit: 1de1d703bf6b7ca14f7b40bbefe9bf6fd6c8ce47
Parents: 0154587
Author: Doris Xin <do...@gmail.com>
Authored: Thu Jun 12 19:44:27 2014 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Thu Jun 12 19:44:27 2014 -0700

----------------------------------------------------------------------
 core/pom.xml                                    |   5 +
 .../main/scala/org/apache/spark/rdd/RDD.scala   |  52 +++---
 .../spark/util/random/RandomSampler.scala       |   2 +-
 .../spark/util/random/SamplingUtils.scala       |  55 ++++++
 .../scala/org/apache/spark/rdd/RDDSuite.scala   |  35 ++--
 .../spark/util/random/SamplingUtilsSuite.scala  |  46 +++++
 project/SparkBuild.scala                        |   1 +
 python/pyspark/rdd.py                           | 167 ++++++++++++-------
 8 files changed, 263 insertions(+), 100 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1de1d703/core/pom.xml
----------------------------------------------------------------------
diff --git a/core/pom.xml b/core/pom.xml
index c3d6b00..be56911 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -68,6 +68,11 @@
       <artifactId>commons-lang3</artifactId>
     </dependency>
     <dependency>
+      <groupId>org.apache.commons</groupId>
+      <artifactId>commons-math3</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
       <groupId>com.google.code.findbugs</groupId>
       <artifactId>jsr305</artifactId>
     </dependency>

http://git-wip-us.apache.org/repos/asf/spark/blob/1de1d703/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index b6fc4b1..446f369 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -42,7 +42,7 @@ import org.apache.spark.partial.PartialResult
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.{BoundedPriorityQueue, Utils}
 import org.apache.spark.util.collection.OpenHashMap
-import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler}
+import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}
 
 /**
  * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -378,46 +378,56 @@ abstract class RDD[T: ClassTag](
     }.toArray
   }
 
-  def takeSample(withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] =
-  {
-    var fraction = 0.0
-    var total = 0
-    val multiplier = 3.0
-    val initialCount = this.count()
-    var maxSelected = 0
+  /**
+   * Return a fixed-size sampled subset of this RDD in an array
+   *
+   * @param withReplacement whether sampling is done with replacement
+   * @param num size of the returned sample
+   * @param seed seed for the random number generator
+   * @return array of num sampled elements (fewer if !withReplacement and num > count())
+   */
+  def takeSample(withReplacement: Boolean,
+      num: Int,
+      seed: Long = Utils.random.nextLong): Array[T] = {
+    val numStDev = 10.0
 
     if (num < 0) {
       throw new IllegalArgumentException("Negative number of elements requested")
+    } else if (num == 0) {
+      return new Array[T](0)
     }
 
+    val initialCount = this.count()
     if (initialCount == 0) {
       return new Array[T](0)
     }
 
-    if (initialCount > Integer.MAX_VALUE - 1) {
-      maxSelected = Integer.MAX_VALUE - 1
-    } else {
-      maxSelected = initialCount.toInt
+    val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt
+    if (num > maxSampleSize) {
+      throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " +
+        s"$numStDev * math.sqrt(Int.MaxValue)")
     }
 
-    if (num > initialCount && !withReplacement) {
-      total = maxSelected
-      fraction = multiplier * (maxSelected + 1) / initialCount
-    } else {
-      fraction = multiplier * (num + 1) / initialCount
-      total = num
+    val rand = new Random(seed)
+    if (!withReplacement && num >= initialCount) {
+      return Utils.randomizeInPlace(this.collect(), rand)
     }
 
-    val rand = new Random(seed)
+    val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount,
+      withReplacement)
+
     var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
 
     // If the first sample didn't turn out large enough, keep trying to take samples;
-    // this shouldn't happen often because we use a big multiplier for the initial size
-    while (samples.length < total) {
+    // this shouldn't happen often because the fraction targets a 99.99% success rate
+    var numIters = 0
+    while (samples.length < num) {
+      numIters += 1
+      logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters")
       samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
     }
 
-    Utils.randomizeInPlace(samples, rand).take(total)
+    Utils.randomizeInPlace(samples, rand).take(num)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/1de1d703/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
index 4dc8ada..247f101 100644
--- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -70,7 +70,7 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
   }
 
   /**
-   *  Return a sampler with is the complement of the range specified of the current sampler.
+   *  Return a sampler over the complement of this sampler's range (same lb and ub).
    */
   def cloneComplement():  BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1de1d703/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
new file mode 100644
index 0000000..a79e3ee
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.random
+
+private[spark] object SamplingUtils {
+
+  /**
+   * Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of
+   * the time.
+   *
+   * How the sampling rate is determined:
+   * Let p = num / total, where num is the sample size and total is the total number of
+   * datapoints in the RDD. We're trying to compute q > p such that
+   *   - when sampling with replacement, each datapoint is drawn k_i ~ Pois(q) times, and we
+   *     want Pr[s < num] < 0.0001 for s = sum(k_i for i from 1 to total) ~ Pois(q * total),
+   *     i.e. the failure rate of not having a sufficiently large sample < 0.0001.
+   *     The code uses q = p + 5 * sqrt(p/total) for num >= 12 and the larger, empirically
+   *     determined q = p + 9 * sqrt(p/total) for num < 12.
+   *   - when sampling without replacement, each datapoint is kept with probability q, so
+   *     s ~ Binomial(total, q); q is the solution of the Chernoff lower-tail bound
+   *     Pr[s <= num] <= delta = 1e-4 (capped at 1), i.e. the same 0.9999 success rate.
+   *
+   * @param sampleSizeLowerBound desired minimum sample size (num above)
+   * @param total size of RDD; must be > 0
+   * @param withReplacement whether sampling with replacement
+   * @return a sampling rate that guarantees sufficient sample size with 99.99% success rate
+   */
+  def computeFractionForSampleSize(sampleSizeLowerBound: Int, total: Long,
+      withReplacement: Boolean): Double = {
+    val fraction = sampleSizeLowerBound.toDouble / total
+    if (withReplacement) {
+      val numStDev = if (sampleSizeLowerBound < 12) 9 else 5
+      fraction + numStDev * math.sqrt(fraction / total)
+    } else {
+      val delta = 1e-4
+      val gamma = - math.log(delta) / total
+      math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/1de1d703/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 2e2ccc5..e94a1e7 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -505,55 +505,56 @@ class RDDSuite extends FunSuite with SharedSparkContext {
   }
 
   test("takeSample") {
-    val data = sc.parallelize(1 to 100, 2)
+    val n = 1000000
+    val data = sc.parallelize(1 to n, 2)
     
     for (num <- List(5, 20, 100)) {
       val sample = data.takeSample(withReplacement=false, num=num)
       assert(sample.size === num)        // Got exactly num elements
       assert(sample.toSet.size === num)  // Elements are distinct
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
     }
     for (seed <- 1 to 5) {
       val sample = data.takeSample(withReplacement=false, 20, seed)
       assert(sample.size === 20)        // Got exactly 20 elements
       assert(sample.toSet.size === 20)  // Elements are distinct
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
     }
     for (seed <- 1 to 5) {
-      val sample = data.takeSample(withReplacement=false, 200, seed)
+      val sample = data.takeSample(withReplacement=false, 100, seed)
       assert(sample.size === 100)        // Got only 100 elements
       assert(sample.toSet.size === 100)  // Elements are distinct
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
     }
     for (seed <- 1 to 5) {
       val sample = data.takeSample(withReplacement=true, 20, seed)
       assert(sample.size === 20)        // Got exactly 20 elements
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
     }
     {
       val sample = data.takeSample(withReplacement=true, num=20)
       assert(sample.size === 20)        // Got exactly 100 elements
       assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements")
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
     }
     {
-      val sample = data.takeSample(withReplacement=true, num=100)
-      assert(sample.size === 100)        // Got exactly 100 elements
+      val sample = data.takeSample(withReplacement=true, num=n)
+      assert(sample.size === n)        // Got exactly n elements
       // Chance of getting all distinct elements is astronomically low, so test we got < 100
-      assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
+      assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
     }
     for (seed <- 1 to 5) {
-      val sample = data.takeSample(withReplacement=true, 100, seed)
-      assert(sample.size === 100)        // Got exactly 100 elements
+      val sample = data.takeSample(withReplacement=true, n, seed)
+      assert(sample.size === n)        // Got exactly n elements
       // Chance of getting all distinct elements is astronomically low, so test we got < 100
-      assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+      assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
     }
     for (seed <- 1 to 5) {
-      val sample = data.takeSample(withReplacement=true, 200, seed)
-      assert(sample.size === 200)        // Got exactly 200 elements
+      val sample = data.takeSample(withReplacement=true, 2 * n, seed)
+      assert(sample.size === 2 * n)        // Got exactly 2 * n elements
       // Chance of getting all distinct elements is still quite low, so test we got < 100
-      assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+      assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1de1d703/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
new file mode 100644
index 0000000..accfe2e
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.random
+
+import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution}
+import org.scalatest.FunSuite
+
+class SamplingUtilsSuite extends FunSuite {
+
+  test("computeFraction") {
+    // test that the computed fraction guarantees enough data points in the sample with a
+    // failure rate <= 0.0001; inverseCumulativeProbability returns a sample size (a count)
+    val n = 100000
+
+    for (s <- 1 to 15) {
+      val frac = SamplingUtils.computeFractionForSampleSize(s, n, true)
+      val poisson = new PoissonDistribution(frac * n)
+      assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+    }
+    for (s <- List(20, 100, 1000)) {
+      val frac = SamplingUtils.computeFractionForSampleSize(s, n, true)
+      val poisson = new PoissonDistribution(frac * n)
+      assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+    }
+    for (s <- List(1, 10, 100, 1000)) {
+      val frac = SamplingUtils.computeFractionForSampleSize(s, n, false)
+      val binomial = new BinomialDistribution(n, frac)
+      assert(binomial.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/1de1d703/project/SparkBuild.scala
----------------------------------------------------------------------
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 8b4885d..2d60a44 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -349,6 +349,7 @@ object SparkBuild extends Build {
     libraryDependencies ++= Seq(
         "com.google.guava"           % "guava"            % "14.0.1",
         "org.apache.commons"         % "commons-lang3"    % "3.3.2",
+        "org.apache.commons"         % "commons-math3"    % "3.3" % "test",
         "com.google.code.findbugs"   % "jsr305"           % "1.3.9",
         "log4j"                      % "log4j"            % "1.2.17",
         "org.slf4j"                  % "slf4j-api"        % slf4jVersion,

http://git-wip-us.apache.org/repos/asf/spark/blob/1de1d703/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 735389c..ddd2285 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -31,6 +31,7 @@ from threading import Thread
 import warnings
 import heapq
 from random import Random
+from math import sqrt, log
 
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
@@ -202,9 +203,9 @@ class RDD(object):
 
     def persist(self, storageLevel):
         """
-        Set this RDD's storage level to persist its values across operations after the first time
-        it is computed. This can only be used to assign a new storage level if the RDD does not
-        have a storage level set yet.
+        Set this RDD's storage level to persist its values across operations
+        after the first time it is computed. This can only be used to assign
+        a new storage level if the RDD does not have a storage level set yet.
         """
         self.is_cached = True
         javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
@@ -213,7 +214,8 @@ class RDD(object):
 
     def unpersist(self):
         """
-        Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+        Mark the RDD as non-persistent, and remove all blocks for it from
+        memory and disk.
         """
         self.is_cached = False
         self._jrdd.unpersist()
@@ -357,48 +359,87 @@ class RDD(object):
     # this is ported from scala/spark/RDD.scala
     def takeSample(self, withReplacement, num, seed=None):
         """
-        Return a fixed-size sampled subset of this RDD (currently requires numpy).
+        Return a fixed-size sampled subset of this RDD (currently requires
+        numpy).
 
-        >>> sc.parallelize(range(0, 10)).takeSample(True, 10, 1) #doctest: +SKIP
-        [4, 2, 1, 8, 2, 7, 0, 4, 1, 4]
+        >>> rdd = sc.parallelize(range(0, 10))
+        >>> len(rdd.takeSample(True, 20, 1))
+        20
+        >>> len(rdd.takeSample(False, 5, 2))
+        5
+        >>> len(rdd.takeSample(False, 15, 3))
+        10
         """
+        numStDev = 10.0  # headroom (in std devs) kept below sys.maxint
+
+        if num < 0:
+            raise ValueError("Sample size cannot be negative.")
+        elif num == 0:
+            return []
 
-        fraction = 0.0
-        total = 0
-        multiplier = 3.0
         initialCount = self.count()
-        maxSelected = 0
+        if initialCount == 0:
+            return []
 
-        if (num < 0):
-            raise ValueError
+        rand = Random(seed)
 
-        if (initialCount == 0):
-            return list()
+        if (not withReplacement) and num >= initialCount:
+            # every element is requested: return the whole RDD, shuffled
+            samples = self.collect()
+            rand.shuffle(samples)
+            return samples
 
-        if initialCount > sys.maxint - 1:
-            maxSelected = sys.maxint - 1
-        else:
-            maxSelected = initialCount
-
-        if num > initialCount and not withReplacement:
-            total = maxSelected
-            fraction = multiplier * (maxSelected + 1) / initialCount
-        else:
-            fraction = multiplier * (num + 1) / initialCount
-            total = num
+        maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint))
+        if num > maxSampleSize:
+            raise ValueError("Sample size cannot be greater than %d." % maxSampleSize)
 
+        fraction = RDD._computeFractionForSampleSize(num, initialCount, withReplacement)
         samples = self.sample(withReplacement, fraction, seed).collect()
 
         # If the first sample didn't turn out large enough, keep trying to take samples;
-        # this shouldn't happen often because we use a big multiplier for their initial size.
+        # this shouldn't happen often because the fraction targets a 99.99% success rate.
         # See: scala/spark/RDD.scala
-        rand = Random(seed)
-        while len(samples) < total:
-            samples = self.sample(withReplacement, fraction, rand.randint(0, sys.maxint)).collect()
-
-        sampler = RDDSampler(withReplacement, fraction, rand.randint(0, sys.maxint))
-        sampler.shuffle(samples)
-        return samples[0:total]
+        while len(samples) < num:
+            # TODO: add log warning for when more than one iteration was run
+            seed = rand.randint(0, sys.maxint)
+            samples = self.sample(withReplacement, fraction, seed).collect()
+
+        rand.shuffle(samples)
+
+        return samples[0:num]
+
+    @staticmethod
+    def _computeFractionForSampleSize(sampleSizeLowerBound, total, withReplacement):
+        """
+        Returns a sampling rate that guarantees a sample of
+        size >= sampleSizeLowerBound 99.99% of the time.
+
+        How the sampling rate is determined:
+        Let p = num / total, where num is the sample size and total is the
+        total number of data points in the RDD. We're trying to compute
+        q > p such that
+          - when sampling with replacement, each data point is drawn
+            k_i ~ Pois(q) times, and we want Pr[s < num] < 0.0001 for
+            s = sum(k_i for i from 1 to total), i.e. the failure rate of
+            not having a sufficiently large sample < 0.0001. The code uses
+            q = p + 5 * sqrt(p/total) for num >= 12 and the larger,
+            empirically determined q = p + 9 * sqrt(p/total) for num < 12.
+          - when sampling without replacement, each data point is kept
+            with probability q, so s ~ Binomial(total, q); q solves the
+            Chernoff lower-tail bound Pr[s <= num] <= delta with
+            delta = 0.00005, i.e. at least 0.9999 success; q is capped at 1.
+        total must be > 0 (float division by zero raises ZeroDivisionError).
+        """
+        fraction = float(sampleSizeLowerBound) / total
+        if withReplacement:
+            numStDev = 5
+            if (sampleSizeLowerBound < 12):
+                numStDev = 9
+            return fraction + numStDev * sqrt(fraction / total)
+        else:
+            delta = 0.00005
+            gamma = - log(delta) / total
+            return min(1, fraction + gamma + sqrt(gamma * gamma + 2 * gamma * fraction))
 
     def union(self, other):
         """
@@ -422,8 +463,8 @@ class RDD(object):
 
     def intersection(self, other):
         """
-        Return the intersection of this RDD and another one. The output will not
-        contain any duplicate elements, even if the input RDDs did.
+        Return the intersection of this RDD and another one. The output will
+        not contain any duplicate elements, even if the input RDDs did.
 
         Note that this method performs a shuffle internally.
 
@@ -665,8 +706,8 @@ class RDD(object):
         modify C{t2}.
 
         The first function (seqOp) can return a different result type, U, than
-        the type of this RDD. Thus, we need one operation for merging a T into an U
-        and one operation for merging two U
+        the type of this RDD. Thus, we need one operation for merging a T into
+        an U and one operation for merging two U
 
         >>> seqOp = (lambda x, y: (x[0] + y, x[1] + 1))
         >>> combOp = (lambda x, y: (x[0] + y[0], x[1] + y[1]))
@@ -759,8 +800,9 @@ class RDD(object):
 
     def sampleStdev(self):
         """
-        Compute the sample standard deviation of this RDD's elements (which corrects for bias in
-        estimating the standard deviation by dividing by N-1 instead of N).
+        Compute the sample standard deviation of this RDD's elements (which
+        corrects for bias in estimating the standard deviation by dividing by
+        N-1 instead of N).
 
         >>> sc.parallelize([1, 2, 3]).sampleStdev()
         1.0
@@ -769,8 +811,8 @@ class RDD(object):
 
     def sampleVariance(self):
         """
-        Compute the sample variance of this RDD's elements (which corrects for bias in
-        estimating the variance by dividing by N-1 instead of N).
+        Compute the sample variance of this RDD's elements (which corrects
+        for bias in estimating the variance by dividing by N-1 instead of N).
 
         >>> sc.parallelize([1, 2, 3]).sampleVariance()
         1.0
@@ -822,8 +864,8 @@ class RDD(object):
 
     def takeOrdered(self, num, key=None):
         """
-        Get the N elements from a RDD ordered in ascending order or as specified
-        by the optional key function.
+        Get the N elements from a RDD ordered in ascending order or as
+        specified by the optional key function.
 
         >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6)
         [1, 2, 3, 4, 5, 6]
@@ -912,8 +954,9 @@ class RDD(object):
 
     def saveAsPickleFile(self, path, batchSize=10):
         """
-        Save this RDD as a SequenceFile of serialized objects. The serializer used is
-        L{pyspark.serializers.PickleSerializer}, default batch size is 10.
+        Save this RDD as a SequenceFile of serialized objects. The serializer
+        used is L{pyspark.serializers.PickleSerializer}, default batch size
+        is 10.
 
         >>> tmpFile = NamedTemporaryFile(delete=True)
         >>> tmpFile.close()
@@ -1195,9 +1238,10 @@ class RDD(object):
 
     def foldByKey(self, zeroValue, func, numPartitions=None):
         """
-        Merge the values for each key using an associative function "func" and a neutral "zeroValue"
-        which may be added to the result an arbitrary number of times, and must not change
-        the result (e.g., 0 for addition, or 1 for multiplication.).
+        Merge the values for each key using an associative function "func"
+        and a neutral "zeroValue" which may be added to the result an
+        arbitrary number of times, and must not change the result
+        (e.g., 0 for addition, or 1 for multiplication.).
 
         >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
         >>> from operator import add
@@ -1217,8 +1261,8 @@ class RDD(object):
         Hash-partitions the resulting RDD with into numPartitions partitions.
 
         Note: If you are grouping in order to perform an aggregation (such as a
-        sum or average) over each key, using reduceByKey will provide much better
-        performance.
+        sum or average) over each key, using reduceByKey will provide much
+        better performance.
 
         >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
         >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
@@ -1278,8 +1322,8 @@ class RDD(object):
     def cogroup(self, other, numPartitions=None):
         """
         For each key k in C{self} or C{other}, return a resulting RDD that
-        contains a tuple with the list of values for that key in C{self} as well
-        as C{other}.
+        contains a tuple with the list of values for that key in C{self} as
+        well as C{other}.
 
         >>> x = sc.parallelize([("a", 1), ("b", 4)])
         >>> y = sc.parallelize([("a", 2)])
@@ -1290,8 +1334,8 @@ class RDD(object):
 
     def subtractByKey(self, other, numPartitions=None):
         """
-        Return each (key, value) pair in C{self} that has no pair with matching key
-        in C{other}.
+        Return each (key, value) pair in C{self} that has no pair with matching
+        key in C{other}.
 
         >>> x = sc.parallelize([("a", 1), ("b", 4), ("b", 5), ("a", 2)])
         >>> y = sc.parallelize([("a", 3), ("c", None)])
@@ -1329,10 +1373,10 @@ class RDD(object):
         """
          Return a new RDD that has exactly numPartitions partitions.
 
-         Can increase or decrease the level of parallelism in this RDD. Internally, this uses
-         a shuffle to redistribute data.
-         If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
-         which can avoid performing a shuffle.
+         Can increase or decrease the level of parallelism in this RDD.
+         Internally, this uses a shuffle to redistribute data.
+         If you are decreasing the number of partitions in this RDD, consider
+         using `coalesce`, which can avoid performing a shuffle.
          >>> rdd = sc.parallelize([1,2,3,4,5,6,7], 4)
          >>> sorted(rdd.glom().collect())
          [[1], [2, 3], [4, 5], [6, 7]]
@@ -1357,9 +1401,10 @@ class RDD(object):
 
     def zip(self, other):
         """
-        Zips this RDD with another one, returning key-value pairs with the first element in each RDD
-        second element in each RDD, etc. Assumes that the two RDDs have the same number of
-        partitions and the same number of elements in each partition (e.g. one was made through
+        Zips this RDD with another one, returning key-value pairs with the
+        first element in each RDD second element in each RDD, etc. Assumes
+        that the two RDDs have the same number of partitions and the same
+        number of elements in each partition (e.g. one was made through
         a map on the other).
 
         >>> x = sc.parallelize(range(0,5))