You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2014/07/18 21:41:54 UTC

git commit: Reservoir sampling implementation.

Repository: spark
Updated Branches:
  refs/heads/master 7f87ab981 -> 586e716e4


Reservoir sampling implementation.

This is going to be used in https://issues.apache.org/jira/browse/SPARK-2568

Author: Reynold Xin <rx...@apache.org>

Closes #1478 from rxin/reservoirSample and squashes the following commits:

17bcbf3 [Reynold Xin] Added seed.
badf20d [Reynold Xin] Renamed the method.
6940010 [Reynold Xin] Reservoir sampling implementation.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/586e716e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/586e716e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/586e716e

Branch: refs/heads/master
Commit: 586e716e47305cd7c2c3ff35c0e828b63ef2f6a8
Parents: 7f87ab9
Author: Reynold Xin <rx...@apache.org>
Authored: Fri Jul 18 12:41:50 2014 -0700
Committer: Reynold Xin <rx...@apache.org>
Committed: Fri Jul 18 12:41:50 2014 -0700

----------------------------------------------------------------------
 .../spark/util/random/SamplingUtils.scala       | 46 ++++++++++++++++++++
 .../spark/util/random/SamplingUtilsSuite.scala  | 21 +++++++++
 2 files changed, 67 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/586e716e/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
index a79e3ee..d10141b 100644
--- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
@@ -17,9 +17,55 @@
 
 package org.apache.spark.util.random
 
+import scala.reflect.ClassTag
+import scala.util.Random
+
 private[spark] object SamplingUtils {
 
   /**
+   * Reservoir sampling implementation that also returns the input size.
+   *
+   * @param input iterator over the elements to sample from; it is consumed entirely
+   * @param k reservoir size, i.e. the maximum number of samples to return
+   * @param seed random seed
+   * @return (samples, input size)
+   */
+  def reservoirSampleAndCount[T: ClassTag](
+      input: Iterator[T],
+      k: Int,
+      seed: Long = Random.nextLong())
+    : (Array[T], Int) = {
+    val reservoir = new Array[T](k)
+    // Put the first k elements in the reservoir.
+    var i = 0
+    while (i < k && input.hasNext) {
+      val item = input.next()
+      reservoir(i) = item
+      i += 1
+    }
+
+    // If we have consumed all the elements, return them. Otherwise do the replacement.
+    if (i < k) {
+      // If input size < k, trim the array to return only an array of input size.
+      val trimReservoir = new Array[T](i)
+      System.arraycopy(reservoir, 0, trimReservoir, 0, i)
+      (trimReservoir, i)
+    } else {
+      // Input size >= k: item number i (1-based) replaces a random slot with probability k / i.
+      val rand = new XORShiftRandom(seed)
+      while (input.hasNext) {
+        val item = input.next()
+        i += 1
+        val replacementIndex = rand.nextInt(i)
+        if (replacementIndex < k) {
+          reservoir(replacementIndex) = item
+        }
+      }
+      (reservoir, i)
+    }
+  }
+
+  /**
    * Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of
    * the time.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/586e716e/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
index accfe2e..73a9d02 100644
--- a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
@@ -17,11 +17,32 @@
 
 package org.apache.spark.util.random
 
+import scala.util.Random
+
 import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution}
 import org.scalatest.FunSuite
 
 class SamplingUtilsSuite extends FunSuite {
 
+  test("reservoirSampleAndCount") {
+    val input = Seq.fill(100)(Random.nextInt())
+
+    // input size < k: every element is returned, in input order
+    val (sample1, count1) = SamplingUtils.reservoirSampleAndCount(input.iterator, 150)
+    assert(count1 === 100)
+    assert(input === sample1.toSeq)
+
+    // input size == k: the reservoir is exactly the input, in input order
+    val (sample2, count2) = SamplingUtils.reservoirSampleAndCount(input.iterator, 100)
+    assert(count2 === 100)
+    assert(input === sample2.toSeq)
+
+    // input size > k: only k elements are kept, but the count still covers the whole input
+    val (sample3, count3) = SamplingUtils.reservoirSampleAndCount(input.iterator, 10)
+    assert(count3 === 100)
+    assert(sample3.length === 10)
+  }
+
   test("computeFraction") {
     // test that the computed fraction guarantees enough data points
     // in the sample with a failure rate <= 0.0001