Posted to commits@spark.apache.org by sr...@apache.org on 2016/12/07 09:34:53 UTC

spark git commit: [SPARK-18678][ML] Skewed reservoir sampling in SamplingUtils

Repository: spark
Updated Branches:
  refs/heads/master b82802713 -> 79f5f281b


[SPARK-18678][ML] Skewed reservoir sampling in SamplingUtils

## What changes were proposed in this pull request?

Fix reservoir sampling bias for small k. An off-by-one error meant that the probability of replacement was slightly too high -- k/(l-1) after the l-th element instead of k/l -- which matters for small k.
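
As a worked illustration of the skew (a standalone sketch under assumed names, not the actual patch -- the real change is in the diff below): with k = 1 and a two-element input, the old code drew the replacement index while l was still 1 when the second element arrived, so the second element always overwrote the first and the first element was never returned. Incrementing l before the draw samples the index from [0, 2), giving the correct probability of 1/2.

```scala
import scala.reflect.ClassTag
import scala.util.Random

// Illustrative sketch of reservoir sampling with the corrected order of
// operations. The real implementation is
// org.apache.spark.util.random.SamplingUtils.reservoirSampleAndCount;
// this sketch omits the short-input case the real code handles.
object ReservoirSketch {
  def sampleAndCount[T: ClassTag](input: Iterator[T], k: Int, seed: Long = 42L): (Array[T], Long) = {
    val reservoir = new Array[T](k)
    var i = 0
    while (i < k && input.hasNext) {     // fill the reservoir with the first k elements
      reservoir(i) = input.next()
      i += 1
    }
    var l = i.toLong
    val rand = new Random(seed)
    while (input.hasNext) {
      val item = input.next()
      l += 1                             // count the l-th element *before* drawing
      // Draw uniformly from [0, l); the new item is kept with probability k/l.
      // (The buggy version incremented l after the draw, so the effective
      // probability was k/(l-1).)
      val replacementIndex = (rand.nextDouble() * l).toLong
      if (replacementIndex < k) {
        reservoir(replacementIndex.toInt) = item
      }
    }
    (reservoir, l)
  }
}
```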

## How was this patch tested?

Existing test plus new test case.

Author: Sean Owen <so...@cloudera.com>

Closes #16129 from srowen/SPARK-18678.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/79f5f281
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/79f5f281
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/79f5f281

Branch: refs/heads/master
Commit: 79f5f281bb69cb2de9f64006180abd753e8ae427
Parents: b828027
Author: Sean Owen <so...@cloudera.com>
Authored: Wed Dec 7 17:34:45 2016 +0800
Committer: Sean Owen <so...@cloudera.com>
Committed: Wed Dec 7 17:34:45 2016 +0800

----------------------------------------------------------------------
 R/pkg/inst/tests/testthat/test_mllib.R                 |  9 +++++----
 .../org/apache/spark/util/random/SamplingUtils.scala   |  5 ++++-
 .../apache/spark/util/random/SamplingUtilsSuite.scala  | 13 +++++++++++++
 3 files changed, 22 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/79f5f281/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 0802a2a..4758e40 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -1007,10 +1007,11 @@ test_that("spark.randomForest", {
   model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
                               numTrees = 20, seed = 123)
   predictions <- collect(predict(model, data))
-  expect_equal(predictions$prediction, c(60.379, 61.096, 60.636, 62.258,
-                                         63.736, 64.296, 64.868, 64.300,
-                                         66.709, 67.697, 67.966, 67.252,
-                                         68.866, 69.593, 69.195, 69.658),
+  expect_equal(predictions$prediction, c(60.32820, 61.22315, 60.69025, 62.11070,
+                                         63.53160, 64.05470, 65.12710, 64.30450,
+                                         66.70910, 67.86125, 68.08700, 67.21865,
+                                         68.89275, 69.53180, 69.39640, 69.68250),
+
                tolerance = 1e-4)
   stats <- summary(model)
   expect_equal(stats$numTrees, 20)

http://git-wip-us.apache.org/repos/asf/spark/blob/79f5f281/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
index 297524c..a7e0075 100644
--- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
@@ -56,11 +56,14 @@ private[spark] object SamplingUtils {
       val rand = new XORShiftRandom(seed)
       while (input.hasNext) {
         val item = input.next()
+        l += 1
+        // There are k elements in the reservoir, and the l-th element has been
+        // consumed. It should be chosen with probability k/l. The expression
+        // below is a random long chosen uniformly from [0,l)
         val replacementIndex = (rand.nextDouble() * l).toLong
         if (replacementIndex < k) {
           reservoir(replacementIndex.toInt) = item
         }
-        l += 1
       }
       (reservoir, l)
     }
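
For completeness, a note on why increment-then-draw is the right order (the standard reservoir-sampling argument, not part of the commit message): the l-th element enters the reservoir with probability k/l, and any element already in the reservoir is evicted at that step with probability (k/l) * (1/k) = 1/l, i.e. it survives with probability (l-1)/l. For an n-element input, element j (with j > k) is therefore kept with probability (k/j) * prod_{l=j+1..n} (l-1)/l = (k/j) * (j/n) = k/n, and each of the first k elements survives with probability prod_{l=k+1..n} (l-1)/l = k/n as well, so every element is sampled uniformly.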

http://git-wip-us.apache.org/repos/asf/spark/blob/79f5f281/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
index 667a4db..55c5dd5 100644
--- a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
@@ -44,6 +44,19 @@ class SamplingUtilsSuite extends SparkFunSuite {
     assert(sample3.length === 10)
   }
 
+  test("SPARK-18678 reservoirSampleAndCount with tiny input") {
+    val input = Seq(0, 1)
+    val counts = new Array[Int](input.size)
+    for (i <- 0 until 500) {
+      val (samples, inputSize) = SamplingUtils.reservoirSampleAndCount(input.iterator, 1)
+      assert(inputSize === 2)
+      assert(samples.length === 1)
+      counts(samples.head) += 1
+    }
+    // If correct, should be true with prob ~ 0.99999707
+    assert(math.abs(counts(0) - counts(1)) <= 100)
+  }
+
   test("computeFraction") {
     // test that the computed fraction guarantees enough data points
     // in the sample with a failure rate <= 0.0001

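A note on the probability quoted in the new test's comment (a standalone check, not part of the patch, assuming independent trials): with the fix, each trial selects element 0 or element 1 with probability 1/2, so counts(0) follows Binomial(500, 0.5) and |counts(0) - counts(1)| <= 100 is exactly the event 200 <= counts(0) <= 300. The sketch below computes that probability exactly in log space and should confirm that the assertion fails only with negligible probability.

```scala
// Standalone check of the bound used in the SPARK-18678 test: computes
// P(200 <= X <= 300) for X ~ Binomial(500, 0.5), working in log space to
// avoid overflow in the binomial coefficients.
object ReservoirTestBound {
  def main(args: Array[String]): Unit = {
    val n = 500
    val log2 = math.log(2.0)
    var logC = 0.0                                   // log C(n, 0)
    var p = 0.0                                      // accumulated probability mass
    for (i <- 0 to n) {
      if (i >= 200 && i <= 300) {
        p += math.exp(logC - n * log2)               // C(n, i) / 2^n
      }
      if (i < n) {
        logC += math.log(n - i) - math.log(i + 1)    // C(n, i) -> C(n, i + 1)
      }
    }
    println(f"P(|counts(0) - counts(1)| <= 100) = $p%.8f")
  }
}
```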

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org