You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2016/12/07 09:34:53 UTC
spark git commit: [SPARK-18678][ML] Skewed reservoir sampling in SamplingUtils
Repository: spark
Updated Branches:
refs/heads/master b82802713 -> 79f5f281b
[SPARK-18678][ML] Skewed reservoir sampling in SamplingUtils
## What changes were proposed in this pull request?
Fix reservoir sampling bias for small k. An off-by-one error made the probability of replacement slightly too high: the l-th element was chosen with probability k/(l-1) instead of k/l. The error is most noticeable for small k.
## How was this patch tested?
Existing test plus new test case.
Author: Sean Owen <so...@cloudera.com>
Closes #16129 from srowen/SPARK-18678.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/79f5f281
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/79f5f281
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/79f5f281
Branch: refs/heads/master
Commit: 79f5f281bb69cb2de9f64006180abd753e8ae427
Parents: b828027
Author: Sean Owen <so...@cloudera.com>
Authored: Wed Dec 7 17:34:45 2016 +0800
Committer: Sean Owen <so...@cloudera.com>
Committed: Wed Dec 7 17:34:45 2016 +0800
----------------------------------------------------------------------
R/pkg/inst/tests/testthat/test_mllib.R | 9 +++++----
.../org/apache/spark/util/random/SamplingUtils.scala | 5 ++++-
.../apache/spark/util/random/SamplingUtilsSuite.scala | 13 +++++++++++++
3 files changed, 22 insertions(+), 5 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/79f5f281/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 0802a2a..4758e40 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -1007,10 +1007,11 @@ test_that("spark.randomForest", {
model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
numTrees = 20, seed = 123)
predictions <- collect(predict(model, data))
- expect_equal(predictions$prediction, c(60.379, 61.096, 60.636, 62.258,
- 63.736, 64.296, 64.868, 64.300,
- 66.709, 67.697, 67.966, 67.252,
- 68.866, 69.593, 69.195, 69.658),
+ expect_equal(predictions$prediction, c(60.32820, 61.22315, 60.69025, 62.11070,
+ 63.53160, 64.05470, 65.12710, 64.30450,
+ 66.70910, 67.86125, 68.08700, 67.21865,
+ 68.89275, 69.53180, 69.39640, 69.68250),
+
tolerance = 1e-4)
stats <- summary(model)
expect_equal(stats$numTrees, 20)
http://git-wip-us.apache.org/repos/asf/spark/blob/79f5f281/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
index 297524c..a7e0075 100644
--- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
@@ -56,11 +56,14 @@ private[spark] object SamplingUtils {
val rand = new XORShiftRandom(seed)
while (input.hasNext) {
val item = input.next()
+ l += 1
+ // There are k elements in the reservoir, and the l-th element has been
+ // consumed. It should be chosen with probability k/l. The expression
+ // below is a random long chosen uniformly from [0,l)
val replacementIndex = (rand.nextDouble() * l).toLong
if (replacementIndex < k) {
reservoir(replacementIndex.toInt) = item
}
- l += 1
}
(reservoir, l)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/79f5f281/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
index 667a4db..55c5dd5 100644
--- a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
@@ -44,6 +44,19 @@ class SamplingUtilsSuite extends SparkFunSuite {
assert(sample3.length === 10)
}
+ test("SPARK-18678 reservoirSampleAndCount with tiny input") {
+ val input = Seq(0, 1)
+ val counts = new Array[Int](input.size)
+ for (i <- 0 until 500) {
+ val (samples, inputSize) = SamplingUtils.reservoirSampleAndCount(input.iterator, 1)
+ assert(inputSize === 2)
+ assert(samples.length === 1)
+ counts(samples.head) += 1
+ }
+ // If correct, should be true with prob ~ 0.99999707
+ assert(math.abs(counts(0) - counts(1)) <= 100)
+ }
+
test("computeFraction") {
// test that the computed fraction guarantees enough data points
// in the sample with a failure rate <= 0.0001
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org