Posted to commits@spark.apache.org by me...@apache.org on 2014/07/29 21:49:50 UTC

git commit: [SPARK-2082] stratified sampling in PairRDDFunctions that guarantees exact sample size

Repository: spark
Updated Branches:
  refs/heads/master f0d880e28 -> dc9653641


[SPARK-2082] stratified sampling in PairRDDFunctions that guarantees exact sample size

Implemented stratified sampling that guarantees exact sample size using ScaSRS, with two passes over the RDD for sampling without replacement and three passes for sampling with replacement.
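
For reference, a minimal usage sketch of the new API (assuming a SparkContext named `sc`; the keys and fractions below are illustrative):

    import org.apache.spark.SparkContext._

    // Pair RDD with two strata, keyed by "even" / "odd".
    val pairs = sc.parallelize(1 to 1000).map(i => (if (i % 2 == 0) "even" else "odd", i))

    // Per-key sampling rates.
    val fractions = Map("even" -> 0.1, "odd" -> 0.5)

    // One pass over the RDD; per-stratum sample sizes are approximate.
    val approxSample = pairs.sampleByKey(false, fractions)

    // Additional pass(es) over the RDD; per-stratum sample sizes are exact.
    val exactSample = pairs.sampleByKey(false, fractions, exact = true, seed = 42L)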

Author: Doris Xin <do...@gmail.com>
Author: Xiangrui Meng <me...@databricks.com>

Closes #1025 from dorx/stratified and squashes the following commits:

245439e [Doris Xin] moved minSamplingRate to getUpperBound
eaf5771 [Doris Xin] bug fixes.
17a381b [Doris Xin] fixed a merge issue and a failed unit
ea7d27f [Doris Xin] merge master
b223529 [Xiangrui Meng] use approx bounds for poisson fix poisson mean for waitlisting add unit tests for Java
b3013a4 [Xiangrui Meng] move math3 back to test scope
eecee5f [Doris Xin] Merge branch 'master' into stratified
f4c21f3 [Doris Xin] Reviewer comments
a10e68d [Doris Xin] style fix
a2bf756 [Doris Xin] Merge branch 'master' into stratified
680b677 [Doris Xin] use mapPartitionWithIndex instead
9884a9f [Doris Xin] style fix
bbfb8c9 [Doris Xin] Merge branch 'master' into stratified
ee9d260 [Doris Xin] addressed reviewer comments
6b5b10b [Doris Xin] Merge branch 'master' into stratified
254e03c [Doris Xin] minor fixes and Java API.
4ad516b [Doris Xin] remove unused imports from PairRDDFunctions
bd9dc6e [Doris Xin] unit bug and style violation fixed
1fe1cff [Doris Xin] Changed fractionByKey to a map to enable arg check
944a10c [Doris Xin] [SPARK-2145] Add lower bound on sampling rate
0214a76 [Doris Xin] cleanUp
90d94c0 [Doris Xin] merge master
9e74ab5 [Doris Xin] Separated out most of the logic in sampleByKey
7327611 [Doris Xin] merge master
50581fc [Doris Xin] added a TODO for logging in python
46f6c8c [Doris Xin] fixed the NPE caused by closures being cleaned before being passed into the aggregate function
7e1a481 [Doris Xin] changed the permission on SamplingUtil
1d413ce [Doris Xin] fixed checkstyle issues
9ee94ee [Doris Xin] [SPARK-2082] stratified sampling in PairRDDFunctions that guarantees exact sample size
e3fd6a6 [Doris Xin] Merge branch 'master' into takeSample
7cab53a [Doris Xin] fixed import bug in rdd.py
ffea61a [Doris Xin] SPARK-1939: Refactor takeSample method in RDD
1441977 [Doris Xin] SPARK-1939 Refactor takeSample method in RDD to use ScaSRS


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/dc965364
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/dc965364
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/dc965364

Branch: refs/heads/master
Commit: dc9653641f8806960d79652afa043c3fb84f25d2
Parents: f0d880e
Author: Doris Xin <do...@gmail.com>
Authored: Tue Jul 29 12:49:44 2014 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue Jul 29 12:49:44 2014 -0700

----------------------------------------------------------------------
 .../org/apache/spark/api/java/JavaPairRDD.scala |  69 +++-
 .../org/apache/spark/rdd/PairRDDFunctions.scala |  54 +++-
 .../spark/util/random/SamplingUtils.scala       |  74 ++++-
 .../util/random/StratifiedSamplingUtils.scala   | 316 +++++++++++++++++++
 .../java/org/apache/spark/JavaAPISuite.java     |  37 +++
 .../spark/rdd/PairRDDFunctionsSuite.scala       | 116 +++++++
 pom.xml                                         |   6 +
 7 files changed, 656 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/dc965364/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 4f30814..31bf8dc 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.api.java
 
-import java.util.{Comparator, List => JList}
+import java.util.{Comparator, List => JList, Map => JMap}
 import java.lang.{Iterable => JIterable}
 
 import scala.collection.JavaConversions._
@@ -130,6 +130,73 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
     new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed))
 
   /**
+   * Return a subset of this RDD sampled by key (via stratified sampling).
+   *
+   * Create a sample of this RDD using variable sampling rates for different keys as specified by
+   * `fractions`, a key to sampling rate map.
+   *
+   * If `exact` is set to false, create the sample via simple random sampling, with one pass
+   * over the RDD, to produce a sample of size that's approximately equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over
+   * the RDD to create a sample size that's exactly equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values.
+   */
+  def sampleByKey(withReplacement: Boolean,
+      fractions: JMap[K, Double],
+      exact: Boolean,
+      seed: Long): JavaPairRDD[K, V] =
+    new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, exact, seed))
+
+  /**
+   * Return a subset of this RDD sampled by key (via stratified sampling).
+   *
+   * Create a sample of this RDD using variable sampling rates for different keys as specified by
+   * `fractions`, a key to sampling rate map.
+   *
+   * If `exact` is set to false, create the sample via simple random sampling, with one pass
+   * over the RDD, to produce a sample of size that's approximately equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over
+   * the RDD to create a sample size that's exactly equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values.
+   *
+   * Use Utils.random.nextLong as the default seed for the random number generator
+   */
+  def sampleByKey(withReplacement: Boolean,
+      fractions: JMap[K, Double],
+      exact: Boolean): JavaPairRDD[K, V] =
+    sampleByKey(withReplacement, fractions, exact, Utils.random.nextLong)
+
+  /**
+   * Return a subset of this RDD sampled by key (via stratified sampling).
+   *
+   * Create a sample of this RDD using variable sampling rates for different keys as specified by
+   * `fractions`, a key to sampling rate map.
+   *
+   * Produce a sample of size that's approximately equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via
+   * simple random sampling.
+   */
+  def sampleByKey(withReplacement: Boolean,
+      fractions: JMap[K, Double],
+      seed: Long): JavaPairRDD[K, V] =
+    sampleByKey(withReplacement, fractions, false, seed)
+
+  /**
+   * Return a subset of this RDD sampled by key (via stratified sampling).
+   *
+   * Create a sample of this RDD using variable sampling rates for different keys as specified by
+   * `fractions`, a key to sampling rate map.
+   *
+   * Produce a sample of size that's approximately equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via
+   * simple random sampling.
+   *
+   * Use Utils.random.nextLong as the default seed for the random number generator
+   */
+  def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] =
+    sampleByKey(withReplacement, fractions, false, Utils.random.nextLong)
+
+  /**
    * Return the union of this RDD and another one. Any identical elements will appear multiple
    * times (use `.distinct()` to eliminate them).
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/dc965364/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index c04d162..1af4e5f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -19,12 +19,10 @@ package org.apache.spark.rdd
 
 import java.nio.ByteBuffer
 import java.text.SimpleDateFormat
-import java.util.Date
-import java.util.{HashMap => JHashMap}
+import java.util.{Date, HashMap => JHashMap}
 
+import scala.collection.{Map, mutable}
 import scala.collection.JavaConversions._
-import scala.collection.Map
-import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
 
@@ -34,19 +32,19 @@ import org.apache.hadoop.fs.FileSystem
 import org.apache.hadoop.io.SequenceFile.CompressionType
 import org.apache.hadoop.io.compress.CompressionCodec
 import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat}
-import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, Job => NewAPIHadoopJob,
+import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat,
 RecordWriter => NewRecordWriter, SparkHadoopMapReduceUtil}
-import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
 
 import org.apache.spark._
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.SparkHadoopWriter
 import org.apache.spark.Partitioner.defaultPartitioner
 import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.partial.{BoundedDouble, PartialResult}
 import org.apache.spark.serializer.Serializer
+import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.CompactBuffer
+import org.apache.spark.util.random.StratifiedSamplingUtils
 
 /**
  * Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
@@ -196,6 +194,41 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
   }
 
   /**
+   * Return a subset of this RDD sampled by key (via stratified sampling).
+   *
+   * Create a sample of this RDD using variable sampling rates for different keys as specified by
+   * `fractions`, a key to sampling rate map.
+   *
+   * If `exact` is set to false, create the sample via simple random sampling, with one pass
+   * over the RDD, to produce a sample of size that's approximately equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values; otherwise, use
+   * additional passes over the RDD to create a sample size that's exactly equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values with a 99.99% confidence. When sampling
+   * without replacement, we need one additional pass over the RDD to guarantee sample size;
+   * when sampling with replacement, we need two additional passes.
+   *
+   * @param withReplacement whether to sample with or without replacement
+   * @param fractions map of specific keys to sampling rates
+   * @param seed seed for the random number generator
+   * @param exact whether sample size needs to be exactly math.ceil(fraction * size) per key
+   * @return RDD containing the sampled subset
+   */
+  def sampleByKey(withReplacement: Boolean,
+      fractions: Map[K, Double],
+      exact: Boolean = false,
+      seed: Long = Utils.random.nextLong): RDD[(K, V)] = {
+
+    require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.")
+
+    val samplingFunc = if (withReplacement) {
+      StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, exact, seed)
+    } else {
+      StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, exact, seed)
+    }
+    self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
+  }
+
+  /**
    * Merge the values for each key using an associative reduce function. This will also perform
    * the merging locally on each mapper before sending results to a reducer, similarly to a
    * "combiner" in MapReduce.
@@ -531,6 +564,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
 
   /**
    * Return the key-value pairs in this RDD to the master as a Map.
+   *
+   * Warning: this doesn't return a multimap (so if you have multiple values for the same key, only
+   *          one value per key is preserved in the map returned)
    */
   def collectAsMap(): Map[K, V] = {
     val data = self.collect()

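As a quick sanity check of the exact-size guarantee documented above (a sketch reusing `pairs` and `fractions` from the earlier example; per the docs, equality holds with 99.99% confidence rather than deterministically):

    // Expected per-stratum sample sizes: math.ceil(numItems * samplingRate).
    val expectedSizes = pairs.countByKey().map { case (k, n) =>
      (k, math.ceil(n * fractions(k)).toLong)
    }

    // Draw an exact stratified sample and compare per-key counts to the expected sizes.
    val sampled = pairs.sampleByKey(false, fractions, exact = true, seed = 1L)
    val actualSizes = sampled.countByKey()
    expectedSizes.foreach { case (k, size) => assert(actualSizes(k) == size) }
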
http://git-wip-us.apache.org/repos/asf/spark/blob/dc965364/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
index d10141b..c9a864a 100644
--- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
@@ -81,6 +81,9 @@ private[spark] object SamplingUtils {
    *     ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success
    *     rate, where success rate is defined the same as in sampling with replacement.
    *
+   * The smallest sampling rate supported is 1e-10 (in order to avoid running into the limit of the
+   * RNG's resolution).
+   *
    * @param sampleSizeLowerBound sample size
    * @param total size of RDD
    * @param withReplacement whether sampling with replacement
@@ -88,14 +91,73 @@ private[spark] object SamplingUtils {
    */
   def computeFractionForSampleSize(sampleSizeLowerBound: Int, total: Long,
       withReplacement: Boolean): Double = {
-    val fraction = sampleSizeLowerBound.toDouble / total
     if (withReplacement) {
-      val numStDev = if (sampleSizeLowerBound < 12) 9 else 5
-      fraction + numStDev * math.sqrt(fraction / total)
+      PoissonBounds.getUpperBound(sampleSizeLowerBound) / total
     } else {
-      val delta = 1e-4
-      val gamma = - math.log(delta) / total
-      math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))
+      val fraction = sampleSizeLowerBound.toDouble / total
+      BinomialBounds.getUpperBound(1e-4, total, fraction)
     }
   }
 }
+
+/**
+ * Utility functions that help us determine bounds on adjusted sampling rate to guarantee exact
+ * sample sizes with high confidence when sampling with replacement.
+ */
+private[spark] object PoissonBounds {
+
+  /**
+   * Returns a lambda such that Pr[X > s] is very small, where X ~ Pois(lambda).
+   */
+  def getLowerBound(s: Double): Double = {
+    math.max(s - numStd(s) * math.sqrt(s), 1e-15)
+  }
+
+  /**
+   * Returns a lambda such that Pr[X < s] is very small, where X ~ Pois(lambda).
+   *
+   * @param s sample size
+   */
+  def getUpperBound(s: Double): Double = {
+    math.max(s + numStd(s) * math.sqrt(s), 1e-10)
+  }
+
+  private def numStd(s: Double): Double = {
+    // TODO: Make it tighter.
+    if (s < 6.0) {
+      12.0
+    } else if (s < 16.0) {
+      9.0
+    } else {
+      6.0
+    }
+  }
+}
+
+/**
+ * Utility functions that help us determine bounds on adjusted sampling rate to guarantee exact
+ * sample size with high confidence when sampling without replacement.
+ */
+private[spark] object BinomialBounds {
+
+  val minSamplingRate = 1e-10
+
+  /**
+   * Returns a threshold `p` such that if we conduct n Bernoulli trials with success rate = `p`,
+   * it is very unlikely to have more than `fraction * n` successes.
+   */
+  def getLowerBound(delta: Double, n: Long, fraction: Double): Double = {
+    val gamma = - math.log(delta) / n * (2.0 / 3.0)
+    fraction + gamma - math.sqrt(gamma * gamma + 3 * gamma * fraction)
+  }
+
+  /**
+   * Returns a threshold `p` such that if we conduct n Bernoulli trials with success rate = `p`,
+   * it is very unlikely to have less than `fraction * n` successes.
+   */
+  def getUpperBound(delta: Double, n: Long, fraction: Double): Double = {
+    val gamma = - math.log(delta) / n
+    math.min(1,
+      math.max(minSamplingRate, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)))
+  }
+}

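To make the role of these bounds concrete, here is a standalone sketch that mirrors BinomialBounds.getUpperBound: the inflated Bernoulli rate at which to sample so that ending up with fewer than fraction * n items is very unlikely (the concrete numbers are illustrative):

    // Same formula as BinomialBounds.getUpperBound above: a success rate p such
    // that Pr[Binomial(n, p) < fraction * n] <= delta.
    def upperBound(delta: Double, n: Long, fraction: Double): Double = {
      val minSamplingRate = 1e-10
      val gamma = -math.log(delta) / n
      math.min(1, math.max(minSamplingRate,
        fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)))
    }

    // Sampling 10% of 1,000,000 items with failure probability 1e-4: sample at a
    // slightly inflated rate (~0.1014) and trim back to the exact size afterwards,
    // which is what the stratified sampling utilities below do.
    val p = upperBound(1e-4, 1000000L, 0.1)
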
http://git-wip-us.apache.org/repos/asf/spark/blob/dc965364/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala
new file mode 100644
index 0000000..8f95d7c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.random
+
+import scala.collection.Map
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
+import cern.jet.random.Poisson
+import cern.jet.random.engine.DRand
+
+import org.apache.spark.Logging
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd.RDD
+
+/**
+ * Auxiliary functions and data structures for the sampleByKey method in PairRDDFunctions.
+ *
+ * Essentially, when exact sample size is necessary, we make additional passes over the RDD to
+ * compute the exact threshold value to use for each stratum to guarantee exact sample size with
+ * high probability. This is achieved by maintaining a waitlist of size O(sqrt(s)), where s is the
+ * desired sample size for each stratum.
+ *
+ * Like in simple random sampling, we generate a random value for each item from the
+ * uniform distribution [0.0, 1.0]. All items with values <= min(values of items in the waitlist)
+ * are accepted into the sample instantly. The threshold for instant accept is designed so that
+ * s - numAccepted = O(sqrt(s)), where s is again the desired sample size. Thus, by maintaining a
+ * waitlist size = O(sqrt(s)), we will be able to create a sample of the exact size s by adding
+ * a portion of the waitlist to the set of items that are instantly accepted. The exact threshold
+ * is computed by sorting the values in the waitlist and picking the value at (s - numAccepted).
+ *
+ * Note that since we use the same seed for the RNG when computing the thresholds and the actual
+ * sample, our computed thresholds are guaranteed to produce the desired sample size.
+ *
+ * For more theoretical background on the sampling techniques used here, please refer to
+ * http://jmlr.org/proceedings/papers/v28/meng13a.html
+ */
+
+private[spark] object StratifiedSamplingUtils extends Logging {
+
+  /**
+   * Count the number of items instantly accepted and generate the waitlist for each stratum.
+   *
+   * This is only invoked when exact sample size is required.
+   */
+  def getAcceptanceResults[K, V](rdd: RDD[(K, V)],
+      withReplacement: Boolean,
+      fractions: Map[K, Double],
+      counts: Option[Map[K, Long]],
+      seed: Long): mutable.Map[K, AcceptanceResult] = {
+    val combOp = getCombOp[K]
+    val mappedPartitionRDD = rdd.mapPartitionsWithIndex { case (partition, iter) =>
+      val zeroU: mutable.Map[K, AcceptanceResult] = new mutable.HashMap[K, AcceptanceResult]()
+      val rng = new RandomDataGenerator()
+      rng.reSeed(seed + partition)
+      val seqOp = getSeqOp(withReplacement, fractions, rng, counts)
+      Iterator(iter.aggregate(zeroU)(seqOp, combOp))
+    }
+    mappedPartitionRDD.reduce(combOp)
+  }
+
+  /**
+   * Returns the function used by aggregate to collect sampling statistics for each partition.
+   */
+  def getSeqOp[K, V](withReplacement: Boolean,
+      fractions: Map[K, Double],
+      rng: RandomDataGenerator,
+      counts: Option[Map[K, Long]]):
+    (mutable.Map[K, AcceptanceResult], (K, V)) => mutable.Map[K, AcceptanceResult] = {
+    val delta = 5e-5
+    (result: mutable.Map[K, AcceptanceResult], item: (K, V)) => {
+      val key = item._1
+      val fraction = fractions(key)
+      if (!result.contains(key)) {
+        result += (key -> new AcceptanceResult())
+      }
+      val acceptResult = result(key)
+
+      if (withReplacement) {
+        // compute acceptBound and waitListBound only if they haven't been computed already
+        // since they don't change from iteration to iteration.
+        // TODO change this to the streaming version
+        if (acceptResult.areBoundsEmpty) {
+          val n = counts.get(key)
+          val sampleSize = math.ceil(n * fraction).toLong
+          val lmbd1 = PoissonBounds.getLowerBound(sampleSize)
+          val lmbd2 = PoissonBounds.getUpperBound(sampleSize)
+          acceptResult.acceptBound = lmbd1 / n
+          acceptResult.waitListBound = (lmbd2 - lmbd1) / n
+        }
+        val acceptBound = acceptResult.acceptBound
+        val copiesAccepted = if (acceptBound == 0.0) 0L else rng.nextPoisson(acceptBound)
+        if (copiesAccepted > 0) {
+          acceptResult.numAccepted += copiesAccepted
+        }
+        val copiesWaitlisted = rng.nextPoisson(acceptResult.waitListBound)
+        if (copiesWaitlisted > 0) {
+          acceptResult.waitList ++= ArrayBuffer.fill(copiesWaitlisted)(rng.nextUniform())
+        }
+      } else {
+        // We use the streaming version of the algorithm for sampling without replacement to avoid
+        // using an extra pass over the RDD for computing the count.
+        // Hence, acceptBound and waitListBound change on every iteration.
+        acceptResult.acceptBound =
+          BinomialBounds.getLowerBound(delta, acceptResult.numItems, fraction)
+        acceptResult.waitListBound =
+          BinomialBounds.getUpperBound(delta, acceptResult.numItems, fraction)
+
+        val x = rng.nextUniform()
+        if (x < acceptResult.acceptBound) {
+          acceptResult.numAccepted += 1
+        } else if (x < acceptResult.waitListBound) {
+          acceptResult.waitList += x
+        }
+      }
+      acceptResult.numItems += 1
+      result
+    }
+  }
+
+  /**
+   * Returns the function used to combine results returned by seqOp from different partitions.
+   */
+  def getCombOp[K]: (mutable.Map[K, AcceptanceResult], mutable.Map[K, AcceptanceResult])
+    => mutable.Map[K, AcceptanceResult] = {
+    (result1: mutable.Map[K, AcceptanceResult], result2: mutable.Map[K, AcceptanceResult]) => {
+      // take union of both key sets in case one partition doesn't contain all keys
+      result1.keySet.union(result2.keySet).foreach { key =>
+        // Use result2 to keep the combined result since result1 is usually empty
+        val entry1 = result1.get(key)
+        if (result2.contains(key)) {
+          result2(key).merge(entry1)
+        } else {
+          if (entry1.isDefined) {
+            result2 += (key -> entry1.get)
+          }
+        }
+      }
+      result2
+    }
+  }
+
+  /**
+   * Given the result returned by getCounts, determine the threshold for accepting items to
+   * generate exact sample size.
+   *
+   * To do so, we compute sampleSize = math.ceil(size * samplingRate) for each stratum and compare
+   * it to the number of items that were accepted instantly and the number of items in the waitlist
+   * for that stratum. Most of the time, numAccepted <= sampleSize <= (numAccepted + numWaitlisted),
+   * which means we need to sort the elements in the waitlist by their associated values in order
+   * to find the value T s.t. |{elements in the stratum whose associated values <= T}| = sampleSize.
+   * Note that all elements in the waitlist have values >= bound for instant accept, so a T value
+   * in the waitlist range would allow all elements that were instantly accepted on the first pass
+   * to be included in the sample.
+   */
+  def computeThresholdByKey[K](finalResult: Map[K, AcceptanceResult],
+      fractions: Map[K, Double]): Map[K, Double] = {
+    val thresholdByKey = new mutable.HashMap[K, Double]()
+    for ((key, acceptResult) <- finalResult) {
+      val sampleSize = math.ceil(acceptResult.numItems * fractions(key)).toLong
+      if (acceptResult.numAccepted > sampleSize) {
+        logWarning("Pre-accepted too many")
+        thresholdByKey += (key -> acceptResult.acceptBound)
+      } else {
+        val numWaitListAccepted = (sampleSize - acceptResult.numAccepted).toInt
+        if (numWaitListAccepted >= acceptResult.waitList.size) {
+          logWarning("WaitList too short")
+          thresholdByKey += (key -> acceptResult.waitListBound)
+        } else {
+          thresholdByKey += (key -> acceptResult.waitList.sorted.apply(numWaitListAccepted))
+        }
+      }
+    }
+    thresholdByKey
+  }
+
+  /**
+   * Return the per partition sampling function used for sampling without replacement.
+   *
+   * When exact sample size is required, we make an additional pass over the RDD to determine the
+   * exact sampling rate that guarantees sample size with high confidence.
+   *
+   * The sampling function has a unique seed per partition.
+   */
+  def getBernoulliSamplingFunction[K, V](rdd: RDD[(K,  V)],
+      fractions: Map[K, Double],
+      exact: Boolean,
+      seed: Long): (Int, Iterator[(K, V)]) => Iterator[(K, V)] = {
+    var samplingRateByKey = fractions
+    if (exact) {
+      // determine threshold for each stratum and resample
+      val finalResult = getAcceptanceResults(rdd, false, fractions, None, seed)
+      samplingRateByKey = computeThresholdByKey(finalResult, fractions)
+    }
+    (idx: Int, iter: Iterator[(K, V)]) => {
+      val rng = new RandomDataGenerator
+      rng.reSeed(seed + idx)
+      // Must use the same invoke pattern on the rng as in getSeqOp for without replacement
+      // in order to generate the same sequence of random numbers when creating the sample
+      iter.filter(t => rng.nextUniform() < samplingRateByKey(t._1))
+    }
+  }
+
+  /**
+   * Return the per partition sampling function used for sampling with replacement.
+   *
+   * When exact sample size is required, we make two additional passes over the RDD to determine
+   * the exact sampling rate that guarantees sample size with high confidence. The first pass
+   * counts the number of items in each stratum (group of items with the same key) in the RDD, and
+   * the second pass uses the counts to determine exact sampling rates.
+   *
+   * The sampling function has a unique seed per partition.
+   */
+  def getPoissonSamplingFunction[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)],
+      fractions: Map[K, Double],
+      exact: Boolean,
+      seed: Long): (Int, Iterator[(K, V)]) => Iterator[(K, V)] = {
+    // TODO implement the streaming version of sampling w/ replacement that doesn't require counts
+    if (exact) {
+      val counts = Some(rdd.countByKey())
+      val finalResult = getAcceptanceResults(rdd, true, fractions, counts, seed)
+      val thresholdByKey = computeThresholdByKey(finalResult, fractions)
+      (idx: Int, iter: Iterator[(K, V)]) => {
+        val rng = new RandomDataGenerator()
+        rng.reSeed(seed + idx)
+        iter.flatMap { item =>
+          val key = item._1
+          val acceptBound = finalResult(key).acceptBound
+          // Must use the same invoke pattern on the rng as in getSeqOp for with replacement
+          // in order to generate the same sequence of random numbers when creating the sample
+          val copiesAccepted = if (acceptBound == 0) 0L else rng.nextPoisson(acceptBound)
+          val copiesWaitlisted = rng.nextPoisson(finalResult(key).waitListBound)
+          val copiesInSample = copiesAccepted +
+            (0 until copiesWaitlisted).count(i => rng.nextUniform() < thresholdByKey(key))
+          if (copiesInSample > 0) {
+            Iterator.fill(copiesInSample.toInt)(item)
+          } else {
+            Iterator.empty
+          }
+        }
+      }
+    } else {
+      (idx: Int, iter: Iterator[(K, V)]) => {
+        val rng = new RandomDataGenerator()
+        rng.reSeed(seed + idx)
+        iter.flatMap { item =>
+          val count = rng.nextPoisson(fractions(item._1))
+          if (count > 0) {
+            Iterator.fill(count)(item)
+          } else {
+            Iterator.empty
+          }
+        }
+      }
+    }
+  }
+
+  /** A random data generator that generates both uniform values and Poisson values. */
+  private class RandomDataGenerator {
+    val uniform = new XORShiftRandom()
+    var poisson = new Poisson(1.0, new DRand)
+
+    def reSeed(seed: Long) {
+      uniform.setSeed(seed)
+      poisson = new Poisson(1.0, new DRand(seed.toInt))
+    }
+
+    def nextPoisson(mean: Double): Int = {
+      poisson.nextInt(mean)
+    }
+
+    def nextUniform(): Double = {
+      uniform.nextDouble()
+    }
+  }
+}
+
+/**
+ * Object used by seqOp to keep track of the number of items accepted and items waitlisted per
+ * stratum, as well as the bounds for accepting and waitlisting items.
+ *
+ * The `private[random]` visibility is necessary since this class appears in the return type of seqOp above.
+ */
+private[random] class AcceptanceResult(var numItems: Long = 0L, var numAccepted: Long = 0L)
+  extends Serializable {
+
+  val waitList = new ArrayBuffer[Double]
+  var acceptBound: Double = Double.NaN // upper bound for accepting item instantly
+  var waitListBound: Double = Double.NaN // upper bound for adding item to waitlist
+
+  def areBoundsEmpty = acceptBound.isNaN || waitListBound.isNaN
+
+  def merge(other: Option[AcceptanceResult]): Unit = {
+    if (other.isDefined) {
+      waitList ++= other.get.waitList
+      numAccepted += other.get.numAccepted
+      numItems += other.get.numItems
+    }
+  }
+}

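To illustrate the accept/waitlist/threshold mechanics described in the header comment above, here is a self-contained single-stratum sketch (plain Scala, sampling without replacement with a known item count; it omits the corner-case handling done by computeThresholdByKey and is not the code path Spark uses):

    import scala.util.Random

    // One stratum of n items; target sample size s = ceil(n * fraction).
    val n = 100000
    val fraction = 0.01
    val s = math.ceil(n * fraction).toInt
    val delta = 5e-5

    // Accept/waitlist thresholds, using the same formulas as BinomialBounds
    // (clamping omitted).
    val gammaL = -math.log(delta) / n * (2.0 / 3.0)
    val acceptBound = fraction + gammaL - math.sqrt(gammaL * gammaL + 3 * gammaL * fraction)
    val gammaU = -math.log(delta) / n
    val waitListBound = fraction + gammaU + math.sqrt(gammaU * gammaU + 2 * gammaU * fraction)

    // First pass: values below acceptBound are accepted instantly; values between
    // acceptBound and waitListBound go on the waitlist.
    val rng = new Random(42L)
    val values = Array.fill(n)(rng.nextDouble())
    val numAccepted = values.count(_ < acceptBound)
    val waitList = values.filter(v => v >= acceptBound && v < waitListBound)

    // Exact threshold: the (s - numAccepted)-th smallest waitlisted value.
    val threshold = waitList.sorted.apply(s - numAccepted)

    // Second pass: everything below the threshold is a sample of exactly s items.
    // (The real implementation regenerates the same random values by reusing the
    // seed; here we simply keep them in memory.)
    assert(values.count(_ < threshold) == s)
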
http://git-wip-us.apache.org/repos/asf/spark/blob/dc965364/core/src/test/java/org/apache/spark/JavaAPISuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index f882a86..e8bd65f 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -29,6 +29,7 @@ import scala.Tuple4;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Iterators;
 import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
 import com.google.common.base.Optional;
 import com.google.common.base.Charsets;
 import com.google.common.io.Files;
@@ -1208,4 +1209,40 @@ public class JavaAPISuite implements Serializable {
     pairRDD.collect();  // Works fine
     pairRDD.collectAsMap();  // Used to crash with ClassCastException
   }
+
+  @Test
+  @SuppressWarnings("unchecked")
+  public void sampleByKey() {
+    JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3);
+    JavaPairRDD<Integer, Integer> rdd2 = rdd1.mapToPair(
+      new PairFunction<Integer, Integer, Integer>() {
+        @Override
+        public Tuple2<Integer, Integer> call(Integer i) {
+          return new Tuple2<Integer, Integer>(i % 2, 1);
+        }
+      });
+    Map<Integer, Object> fractions = Maps.newHashMap();
+    fractions.put(0, 0.5);
+    fractions.put(1, 1.0);
+    JavaPairRDD<Integer, Integer> wr = rdd2.sampleByKey(true, fractions, 1L);
+    Map<Integer, Long> wrCounts = (Map<Integer, Long>) (Object) wr.countByKey();
+    Assert.assertTrue(wrCounts.size() == 2);
+    Assert.assertTrue(wrCounts.get(0) > 0);
+    Assert.assertTrue(wrCounts.get(1) > 0);
+    JavaPairRDD<Integer, Integer> wor = rdd2.sampleByKey(false, fractions, 1L);
+    Map<Integer, Long> worCounts = (Map<Integer, Long>) (Object) wor.countByKey();
+    Assert.assertTrue(worCounts.size() == 2);
+    Assert.assertTrue(worCounts.get(0) > 0);
+    Assert.assertTrue(worCounts.get(1) > 0);
+    JavaPairRDD<Integer, Integer> wrExact = rdd2.sampleByKey(true, fractions, true, 1L);
+    Map<Integer, Long> wrExactCounts = (Map<Integer, Long>) (Object) wrExact.countByKey();
+    Assert.assertTrue(wrExactCounts.size() == 2);
+    Assert.assertTrue(wrExactCounts.get(0) == 2);
+    Assert.assertTrue(wrExactCounts.get(1) == 4);
+    JavaPairRDD<Integer, Integer> worExact = rdd2.sampleByKey(false, fractions, true, 1L);
+    Map<Integer, Long> worExactCounts = (Map<Integer, Long>) (Object) worExact.countByKey();
+    Assert.assertTrue(worExactCounts.size() == 2);
+    Assert.assertTrue(worExactCounts.get(0) == 2);
+    Assert.assertTrue(worExactCounts.get(1) == 4);
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/dc965364/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 447e38e..4f49d4a 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -83,6 +83,122 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
     assert(valuesFor2.toList.sorted === List(1))
   }
 
+  test("sampleByKey") {
+    def stratifier (fractionPositive: Double) = {
+      (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0"
+    }
+
+    def checkSize(exact: Boolean,
+        withReplacement: Boolean,
+        expected: Long,
+        actual: Long,
+        p: Double): Boolean = {
+      if (exact) {
+        return expected == actual
+      }
+      val stdev = if (withReplacement) math.sqrt(expected) else math.sqrt(expected * p * (1 - p))
+      // Very forgiving margin since we're dealing with very small sample sizes most of the time
+      math.abs(actual - expected) <= 6 * stdev
+    }
+
+    // Without replacement validation
+    def takeSampleAndValidateBernoulli(stratifiedData: RDD[(String, Int)],
+        exact: Boolean,
+        samplingRate: Double,
+        seed: Long,
+        n: Long) = {
+      val expectedSampleSize = stratifiedData.countByKey()
+        .mapValues(count => math.ceil(count * samplingRate).toInt)
+      val fractions = Map("1" -> samplingRate, "0" -> samplingRate)
+      val sample = stratifiedData.sampleByKey(false, fractions, exact, seed)
+      val sampleCounts = sample.countByKey()
+      val takeSample = sample.collect()
+      sampleCounts.foreach { case(k, v) =>
+        assert(checkSize(exact, false, expectedSampleSize(k), v, samplingRate)) }
+      assert(takeSample.size === takeSample.toSet.size)
+      takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") }
+    }
+
+    // With replacement validation
+    def takeSampleAndValidatePoisson(stratifiedData: RDD[(String, Int)],
+        exact: Boolean,
+        samplingRate: Double,
+        seed: Long,
+        n: Long) = {
+      val expectedSampleSize = stratifiedData.countByKey().mapValues(count =>
+        math.ceil(count * samplingRate).toInt)
+      val fractions = Map("1" -> samplingRate, "0" -> samplingRate)
+      val sample = stratifiedData.sampleByKey(true, fractions, exact, seed)
+      val sampleCounts = sample.countByKey()
+      val takeSample = sample.collect()
+      sampleCounts.foreach { case(k, v) =>
+        assert(checkSize(exact, true, expectedSampleSize(k), v, samplingRate)) }
+      val groupedByKey = takeSample.groupBy(_._1)
+      for ((key, v) <- groupedByKey) {
+        if (expectedSampleSize(key) >= 100 && samplingRate >= 0.1) {
+          // sample large enough for there to be repeats with high likelihood
+          assert(v.toSet.size < expectedSampleSize(key))
+        } else {
+          if (exact) {
+            assert(v.toSet.size <= expectedSampleSize(key))
+          } else {
+            assert(checkSize(false, true, expectedSampleSize(key), v.toSet.size, samplingRate))
+          }
+        }
+      }
+      takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") }
+    }
+
+    def checkAllCombos(stratifiedData: RDD[(String, Int)],
+        samplingRate: Double,
+        seed: Long,
+        n: Long) = {
+      takeSampleAndValidateBernoulli(stratifiedData, true, samplingRate, seed, n)
+      takeSampleAndValidateBernoulli(stratifiedData, false, samplingRate, seed, n)
+      takeSampleAndValidatePoisson(stratifiedData, true, samplingRate, seed, n)
+      takeSampleAndValidatePoisson(stratifiedData, false, samplingRate, seed, n)
+    }
+
+    val defaultSeed = 1L
+
+    // vary RDD size
+    for (n <- List(100, 1000, 1000000)) {
+      val data = sc.parallelize(1 to n, 2)
+      val fractionPositive = 0.3
+      val stratifiedData = data.keyBy(stratifier(fractionPositive))
+
+      val samplingRate = 0.1
+      checkAllCombos(stratifiedData, samplingRate, defaultSeed, n)
+    }
+
+    // vary fractionPositive
+    for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) {
+      val n = 100
+      val data = sc.parallelize(1 to n, 2)
+      val stratifiedData = data.keyBy(stratifier(fractionPositive))
+
+      val samplingRate = 0.1
+      checkAllCombos(stratifiedData, samplingRate, defaultSeed, n)
+    }
+
+    // Use the same data for the rest of the tests
+    val fractionPositive = 0.3
+    val n = 100
+    val data = sc.parallelize(1 to n, 2)
+    val stratifiedData = data.keyBy(stratifier(fractionPositive))
+
+    // vary seed
+    for (seed <- defaultSeed to defaultSeed + 5L) {
+      val samplingRate = 0.1
+      checkAllCombos(stratifiedData, samplingRate, seed, n)
+    }
+
+    // vary sampling rate
+    for (samplingRate <- List(0.01, 0.05, 0.1, 0.5)) {
+      checkAllCombos(stratifiedData, samplingRate, defaultSeed, n)
+    }
+  }
+
   test("reduceByKey") {
     val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
     val sums = pairs.reduceByKey(_+_).collect()

http://git-wip-us.apache.org/repos/asf/spark/blob/dc965364/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index 8b1435c..39538f9 100644
--- a/pom.xml
+++ b/pom.xml
@@ -258,6 +258,12 @@
         <version>1.5</version>
       </dependency>
       <dependency>
+        <groupId>org.apache.commons</groupId>
+        <artifactId>commons-math3</artifactId>
+        <version>3.3</version>
+        <scope>test</scope>
+      </dependency>
+      <dependency>
         <groupId>com.google.code.findbugs</groupId>
         <artifactId>jsr305</artifactId>
         <version>1.3.9</version>