You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2017/08/28 06:41:49 UTC
spark git commit: [SPARK-21818][ML][MLLIB] Fix bug of
MultivariateOnlineSummarizer.variance generate negative result
Repository: spark
Updated Branches:
refs/heads/master 07142cf6d -> 0456b4050
[SPARK-21818][ML][MLLIB] Fix bug of MultivariateOnlineSummarizer.variance generate negative result
## What changes were proposed in this pull request?
Because of numerical error, MultivariateOnlineSummarizer.variance can produce a negative variance.
**This is a serious bug: many algorithms in MLlib**
**use the stddev computed as** `sqrt(variance)`,
**so a negative variance produces NaN and crashes the whole algorithm.**
We can reproduce this bug using the following code:
```
val summarizer1 = (new MultivariateOnlineSummarizer)
.add(Vectors.dense(3.0), 0.7)
val summarizer2 = (new MultivariateOnlineSummarizer)
.add(Vectors.dense(3.0), 0.4)
val summarizer3 = (new MultivariateOnlineSummarizer)
.add(Vectors.dense(3.0), 0.5)
val summarizer4 = (new MultivariateOnlineSummarizer)
.add(Vectors.dense(3.0), 0.4)
val summarizer = summarizer1
.merge(summarizer2)
.merge(summarizer3)
.merge(summarizer4)
println(summarizer.variance(0))
```
This PR fixes the bugs in `mllib.stat.MultivariateOnlineSummarizer.variance` and `ml.stat.SummarizerBuffer.variance`, and in several places in `WeightedLeastSquares`.
## How was this patch tested?
Test cases were added.
Author: WeichenXu <We...@outlook.com>
Closes #19029 from WeichenXu123/fix_summarizer_var_bug.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0456b405
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0456b405
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0456b405
Branch: refs/heads/master
Commit: 0456b4050817e64f27824720e695bbfff738d474
Parents: 07142cf
Author: WeichenXu <We...@outlook.com>
Authored: Mon Aug 28 07:41:42 2017 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Mon Aug 28 07:41:42 2017 +0100
----------------------------------------------------------------------
.../spark/ml/optim/WeightedLeastSquares.scala | 12 +++++++++---
.../org/apache/spark/ml/stat/Summarizer.scala | 5 +++--
.../mllib/stat/MultivariateOnlineSummarizer.scala | 5 +++--
.../apache/spark/ml/stat/SummarizerSuite.scala | 18 ++++++++++++++++++
.../stat/MultivariateOnlineSummarizerSuite.scala | 18 ++++++++++++++++++
5 files changed, 51 insertions(+), 7 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/0456b405/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
index 32b0af7..1ed218a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -440,7 +440,11 @@ private[ml] object WeightedLeastSquares {
/**
* Weighted population standard deviation of labels.
*/
- def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar)
+ def bStd: Double = {
+ // We prevent variance from negative value caused by numerical error.
+ val variance = math.max(bbSum / wSum - bBar * bBar, 0.0)
+ math.sqrt(variance)
+ }
/**
* Weighted mean of (label * features).
@@ -471,7 +475,8 @@ private[ml] object WeightedLeastSquares {
while (i < triK) {
val l = j - 2
val aw = aSum(l) / wSum
- std(l) = math.sqrt(aaValues(i) / wSum - aw * aw)
+ // We prevent variance from negative value caused by numerical error.
+ std(l) = math.sqrt(math.max(aaValues(i) / wSum - aw * aw, 0.0))
i += j
j += 1
}
@@ -489,7 +494,8 @@ private[ml] object WeightedLeastSquares {
while (i < triK) {
val l = j - 2
val aw = aSum(l) / wSum
- variance(l) = aaValues(i) / wSum - aw * aw
+ // We prevent variance from negative value caused by numerical error.
+ variance(l) = math.max(aaValues(i) / wSum - aw * aw, 0.0)
i += j
j += 1
}
http://git-wip-us.apache.org/repos/asf/spark/blob/0456b405/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
index 7e408b9..cae41ed 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
@@ -436,8 +436,9 @@ private[ml] object SummaryBuilderImpl extends Logging {
var i = 0
val len = currM2n.length
while (i < len) {
- realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
- (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
+ // We prevent variance from negative value caused by numerical error.
+ realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
+ (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
i += 1
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/0456b405/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 7dc0c45..8121880 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -213,8 +213,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
var i = 0
val len = currM2n.length
while (i < len) {
- realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
- (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
+ // We prevent variance from negative value caused by numerical error.
+ realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
+ (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
i += 1
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/0456b405/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
index dfb733f..1ea851e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
@@ -402,6 +402,24 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(summarizer.count === 6)
}
+ test("summarizer buffer zero variance test (SPARK-21818)") {
+ val summarizer1 = new SummarizerBuffer()
+ .add(Vectors.dense(3.0), 0.7)
+ val summarizer2 = new SummarizerBuffer()
+ .add(Vectors.dense(3.0), 0.4)
+ val summarizer3 = new SummarizerBuffer()
+ .add(Vectors.dense(3.0), 0.5)
+ val summarizer4 = new SummarizerBuffer()
+ .add(Vectors.dense(3.0), 0.4)
+
+ val summarizer = summarizer1
+ .merge(summarizer2)
+ .merge(summarizer3)
+ .merge(summarizer4)
+
+ assert(summarizer.variance(0) >= 0.0)
+ }
+
test("summarizer buffer merging summarizer with empty summarizer") {
// If one of two is non-empty, this should return the non-empty summarizer.
// If both of them are empty, then just return the empty summarizer.
http://git-wip-us.apache.org/repos/asf/spark/blob/0456b405/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index 797e84f..c6466bc 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -270,4 +270,22 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
}
+
+ test ("test zero variance (SPARK-21818)") {
+ val summarizer1 = (new MultivariateOnlineSummarizer)
+ .add(Vectors.dense(3.0), 0.7)
+ val summarizer2 = (new MultivariateOnlineSummarizer)
+ .add(Vectors.dense(3.0), 0.4)
+ val summarizer3 = (new MultivariateOnlineSummarizer)
+ .add(Vectors.dense(3.0), 0.5)
+ val summarizer4 = (new MultivariateOnlineSummarizer)
+ .add(Vectors.dense(3.0), 0.4)
+
+ val summarizer = summarizer1
+ .merge(summarizer2)
+ .merge(summarizer3)
+ .merge(summarizer4)
+
+ assert(summarizer.variance(0) >= 0.0)
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org