You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2014/07/18 21:41:54 UTC
git commit: Reservoir sampling implementation.
Repository: spark
Updated Branches:
refs/heads/master 7f87ab981 -> 586e716e4
Reservoir sampling implementation.
This is going to be used in https://issues.apache.org/jira/browse/SPARK-2568
Author: Reynold Xin <rx...@apache.org>
Closes #1478 from rxin/reservoirSample and squashes the following commits:
17bcbf3 [Reynold Xin] Added seed.
badf20d [Reynold Xin] Renamed the method.
6940010 [Reynold Xin] Reservoir sampling implementation.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/586e716e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/586e716e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/586e716e
Branch: refs/heads/master
Commit: 586e716e47305cd7c2c3ff35c0e828b63ef2f6a8
Parents: 7f87ab9
Author: Reynold Xin <rx...@apache.org>
Authored: Fri Jul 18 12:41:50 2014 -0700
Committer: Reynold Xin <rx...@apache.org>
Committed: Fri Jul 18 12:41:50 2014 -0700
----------------------------------------------------------------------
.../spark/util/random/SamplingUtils.scala | 46 ++++++++++++++++++++
.../spark/util/random/SamplingUtilsSuite.scala | 21 +++++++++
2 files changed, 67 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/586e716e/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
index a79e3ee..d10141b 100644
--- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
@@ -17,9 +17,55 @@
package org.apache.spark.util.random
+import scala.reflect.ClassTag
+import scala.util.Random
+
private[spark] object SamplingUtils {
/**
+   * Reservoir sampling implementation that also returns the input size.
+   *
+   * Uses the classic "Algorithm R": the first k elements fill the reservoir, then the element at
+   * 0-based position i (i >= k) overwrites a uniformly chosen reservoir slot with probability
+   * k / (i + 1). As a result, every input element ends up in the sample with probability k / n,
+   * where n is the total number of input elements.
+   *
+   * NOTE: the running count is an Int, so inputs with Int.MaxValue or more elements are not
+   * supported (the count would overflow).
+   *
+   * @param input input iterator; it is fully consumed by this call
+   * @param k reservoir size; must be non-negative
+   * @param seed random seed used for the replacement step
+   * @return (samples, input size). If the input has fewer than k elements, the sample array is
+   *         trimmed to the input size and holds all input elements in their original order.
+   */
+  def reservoirSampleAndCount[T: ClassTag](
+      input: Iterator[T],
+      k: Int,
+      seed: Long = Random.nextLong())
+    : (Array[T], Int) = {
+    require(k >= 0, s"Reservoir size must be non-negative but got $k.")
+    val reservoir = new Array[T](k)
+    // Put the first k elements in the reservoir.
+    var i = 0
+    while (i < k && input.hasNext) {
+      reservoir(i) = input.next()
+      i += 1
+    }
+
+    // If we have consumed all the elements, return them. Otherwise do the replacement.
+    if (i < k) {
+      // Input size < k: trim the array to return only the elements actually seen.
+      val trimReservoir = new Array[T](i)
+      System.arraycopy(reservoir, 0, trimReservoir, 0, i)
+      (trimReservoir, i)
+    } else {
+      // Input size >= k: continue the sampling process over the remaining elements.
+      val rand = new XORShiftRandom(seed)
+      while (input.hasNext) {
+        val item = input.next()
+        // `item` is the (i + 1)-th element seen, so it must be kept with probability
+        // k / (i + 1). nextInt(i + 1) is uniform over [0, i] and therefore falls below k with
+        // exactly that probability. Using nextInt(i) would over-sample late elements
+        // (probability k / i) and would fail for k == 0.
+        val replacementIndex = rand.nextInt(i + 1)
+        if (replacementIndex < k) {
+          reservoir(replacementIndex) = item
+        }
+        i += 1
+      }
+      (reservoir, i)
+    }
+  }
+
+ /**
* Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of
* the time.
*
http://git-wip-us.apache.org/repos/asf/spark/blob/586e716e/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
index accfe2e..73a9d02 100644
--- a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
@@ -17,11 +17,32 @@
package org.apache.spark.util.random
+import scala.util.Random
+
import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution}
import org.scalatest.FunSuite
class SamplingUtilsSuite extends FunSuite {
+  test("reservoirSampleAndCount") {
+    val input = Seq.fill(100)(Random.nextInt())
+
+    // input size < k: every element is returned, in input order.
+    val (sample1, count1) = SamplingUtils.reservoirSampleAndCount(input.iterator, 150)
+    assert(count1 === 100)
+    assert(input === sample1.toSeq)
+
+    // input size == k: every element is returned, in input order.
+    val (sample2, count2) = SamplingUtils.reservoirSampleAndCount(input.iterator, 100)
+    assert(count2 === 100)
+    assert(input === sample2.toSeq)
+
+    // input size > k: exactly k elements, all drawn from the input.
+    val (sample3, count3) = SamplingUtils.reservoirSampleAndCount(input.iterator, 10)
+    assert(count3 === 100)
+    assert(sample3.length === 10)
+    val inputSet = input.toSet
+    assert(sample3.forall(inputSet.contains))
+
+    // empty input: empty sample and zero count.
+    val (sample4, count4) = SamplingUtils.reservoirSampleAndCount(Iterator.empty[Int], 10)
+    assert(count4 === 0)
+    assert(sample4.isEmpty)
+  }
+
test("computeFraction") {
// test that the computed fraction guarantees enough data points
// in the sample with a failure rate <= 0.0001