Posted to commits@spark.apache.org by sr...@apache.org on 2017/08/28 06:41:49 UTC

spark git commit: [SPARK-21818][ML][MLLIB] Fix bug where MultivariateOnlineSummarizer.variance can generate a negative result

Repository: spark
Updated Branches:
  refs/heads/master 07142cf6d -> 0456b4050


[SPARK-21818][ML][MLLIB] Fix bug where MultivariateOnlineSummarizer.variance can generate a negative result

## What changes were proposed in this pull request?

Because of numerical error, MultivariateOnlineSummarizer.variance can produce a negative variance.

**This is a serious bug: many algorithms in MLlib compute the standard deviation as** `sqrt(variance)`**, so a negative variance yields NaN and breaks the whole algorithm.**
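To make the failure mode concrete, here is a minimal illustration (not part of the patch; the negative value below is made up to stand in for the tiny negative variance the reproducer prints before this fix):
```
    // Illustrative only: a variance that should be 0.0 but came out slightly
    // negative due to floating-point cancellation.
    val variance = -1.0e-16
    println(math.sqrt(variance))                 // NaN -- what breaks downstream algorithms
    println(math.sqrt(math.max(variance, 0.0)))  // 0.0 -- with the guard this PR adds
```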

The bug can be reproduced with the following code:
```
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer

    val summarizer1 = (new MultivariateOnlineSummarizer)
      .add(Vectors.dense(3.0), 0.7)
    val summarizer2 = (new MultivariateOnlineSummarizer)
      .add(Vectors.dense(3.0), 0.4)
    val summarizer3 = (new MultivariateOnlineSummarizer)
      .add(Vectors.dense(3.0), 0.5)
    val summarizer4 = (new MultivariateOnlineSummarizer)
      .add(Vectors.dense(3.0), 0.4)

    val summarizer = summarizer1
      .merge(summarizer2)
      .merge(summarizer3)
      .merge(summarizer4)

    // Before this fix, this prints a small negative value instead of 0.0.
    println(summarizer.variance(0))
```
This PR fixes the bug in `mllib.stat.MultivariateOnlineSummarizer.variance`, in `ml.stat.SummarizerBuffer.variance`, and in several places in `WeightedLeastSquares` by clamping the computed variance at zero before it is used.
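
The guard is the same in each location: clamp the naively computed variance at zero before using it or taking its square root. A standalone restatement of the pattern, using the `bbSum` / `wSum` / `bBar` names from the `WeightedLeastSquares` hunk below:
```
    def bStd(bbSum: Double, wSum: Double, bBar: Double): Double = {
      // Prevent numerical error from producing a negative variance.
      val variance = math.max(bbSum / wSum - bBar * bBar, 0.0)
      math.sqrt(variance)
    }
```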

## How was this patch tested?

Test cases added in `SummarizerSuite` and `MultivariateOnlineSummarizerSuite`.

Author: WeichenXu <We...@outlook.com>

Closes #19029 from WeichenXu123/fix_summarizer_var_bug.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0456b405
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0456b405
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0456b405

Branch: refs/heads/master
Commit: 0456b4050817e64f27824720e695bbfff738d474
Parents: 07142cf
Author: WeichenXu <We...@outlook.com>
Authored: Mon Aug 28 07:41:42 2017 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Mon Aug 28 07:41:42 2017 +0100

----------------------------------------------------------------------
 .../spark/ml/optim/WeightedLeastSquares.scala     | 12 +++++++++---
 .../org/apache/spark/ml/stat/Summarizer.scala     |  5 +++--
 .../mllib/stat/MultivariateOnlineSummarizer.scala |  5 +++--
 .../apache/spark/ml/stat/SummarizerSuite.scala    | 18 ++++++++++++++++++
 .../stat/MultivariateOnlineSummarizerSuite.scala  | 18 ++++++++++++++++++
 5 files changed, 51 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0456b405/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
index 32b0af7..1ed218a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -440,7 +440,11 @@ private[ml] object WeightedLeastSquares {
     /**
      * Weighted population standard deviation of labels.
      */
-    def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar)
+    def bStd: Double = {
+      // We prevent variance from negative value caused by numerical error.
+      val variance = math.max(bbSum / wSum - bBar * bBar, 0.0)
+      math.sqrt(variance)
+    }
 
     /**
      * Weighted mean of (label * features).
@@ -471,7 +475,8 @@ private[ml] object WeightedLeastSquares {
       while (i < triK) {
         val l = j - 2
         val aw = aSum(l) / wSum
-        std(l) = math.sqrt(aaValues(i) / wSum - aw * aw)
+        // We prevent variance from negative value caused by numerical error.
+        std(l) = math.sqrt(math.max(aaValues(i) / wSum - aw * aw, 0.0))
         i += j
         j += 1
       }
@@ -489,7 +494,8 @@ private[ml] object WeightedLeastSquares {
       while (i < triK) {
         val l = j - 2
         val aw = aSum(l) / wSum
-        variance(l) = aaValues(i) / wSum - aw * aw
+        // We prevent variance from negative value caused by numerical error.
+        variance(l) = math.max(aaValues(i) / wSum - aw * aw, 0.0)
         i += j
         j += 1
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/0456b405/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
index 7e408b9..cae41ed 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
@@ -436,8 +436,9 @@ private[ml] object SummaryBuilderImpl extends Logging {
         var i = 0
         val len = currM2n.length
         while (i < len) {
-          realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
-            (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
+          // We prevent variance from negative value caused by numerical error.
+          realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
+            (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
           i += 1
         }
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/0456b405/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 7dc0c45..8121880 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -213,8 +213,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
       var i = 0
       val len = currM2n.length
       while (i < len) {
-        realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
-          (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
+        // We prevent variance from negative value caused by numerical error.
+        realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
+          (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
         i += 1
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/0456b405/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
index dfb733f..1ea851e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
@@ -402,6 +402,24 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(summarizer.count === 6)
   }
 
+  test("summarizer buffer zero variance test (SPARK-21818)") {
+    val summarizer1 = new SummarizerBuffer()
+      .add(Vectors.dense(3.0), 0.7)
+    val summarizer2 = new SummarizerBuffer()
+      .add(Vectors.dense(3.0), 0.4)
+    val summarizer3 = new SummarizerBuffer()
+      .add(Vectors.dense(3.0), 0.5)
+    val summarizer4 = new SummarizerBuffer()
+      .add(Vectors.dense(3.0), 0.4)
+
+    val summarizer = summarizer1
+      .merge(summarizer2)
+      .merge(summarizer3)
+      .merge(summarizer4)
+
+    assert(summarizer.variance(0) >= 0.0)
+  }
+
   test("summarizer buffer merging summarizer with empty summarizer") {
     // If one of two is non-empty, this should return the non-empty summarizer.
     // If both of them are empty, then just return the empty summarizer.

http://git-wip-us.apache.org/repos/asf/spark/blob/0456b405/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index 797e84f..c6466bc 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -270,4 +270,22 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
     assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
     assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
   }
+
+  test ("test zero variance (SPARK-21818)") {
+    val summarizer1 = (new MultivariateOnlineSummarizer)
+      .add(Vectors.dense(3.0), 0.7)
+    val summarizer2 = (new MultivariateOnlineSummarizer)
+      .add(Vectors.dense(3.0), 0.4)
+    val summarizer3 = (new MultivariateOnlineSummarizer)
+      .add(Vectors.dense(3.0), 0.5)
+    val summarizer4 = (new MultivariateOnlineSummarizer)
+      .add(Vectors.dense(3.0), 0.4)
+
+    val summarizer = summarizer1
+      .merge(summarizer2)
+      .merge(summarizer3)
+      .merge(summarizer4)
+
+    assert(summarizer.variance(0) >= 0.0)
+  }
 }

