You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2014/11/12 10:50:22 UTC
spark git commit: [SPARK-4355][MLLIB] fix OnlineSummarizer.merge when other.mean is zero
Repository: spark
Updated Branches:
refs/heads/master faeb41de2 -> 84324fbcb
[SPARK-4355][MLLIB] fix OnlineSummarizer.merge when other.mean is zero
See the inline comment about the bug. I also did some code clean-up. @dbtsai: I moved the nested `update` helper to a private method of `MultivariateOnlineSummarizer` (it is now named `add`). I don't think it will cause a performance regression, but it would be great if you have some time to test.
Author: Xiangrui Meng <me...@databricks.com>
Closes #3220 from mengxr/SPARK-4355 and squashes the following commits:
5ef601f [Xiangrui Meng] fix OnlineSummarizer.merge when other.mean is zero and some code clean-up
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/84324fbc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/84324fbc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/84324fbc
Branch: refs/heads/master
Commit: 84324fbcb987db6e10e435f463eacace1bae43e2
Parents: faeb41d
Author: Xiangrui Meng <me...@databricks.com>
Authored: Wed Nov 12 01:50:11 2014 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Wed Nov 12 01:50:11 2014 -0800
----------------------------------------------------------------------
.../stat/MultivariateOnlineSummarizer.scala | 85 +++++++++-----------
.../MultivariateOnlineSummarizerSuite.scala | 11 +++
2 files changed, 51 insertions(+), 45 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/84324fbc/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index fab7c44..654479a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -50,6 +50,29 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
private var currMin: BDV[Double] = _
/**
+ * Adds input value to position i.
+ */
+ private[this] def add(i: Int, value: Double) = {
+ if (value != 0.0) {
+ if (currMax(i) < value) {
+ currMax(i) = value
+ }
+ if (currMin(i) > value) {
+ currMin(i) = value
+ }
+
+ val prevMean = currMean(i)
+ val diff = value - prevMean
+ currMean(i) = prevMean + diff / (nnz(i) + 1.0)
+ currM2n(i) += (value - currMean(i)) * diff
+ currM2(i) += value * value
+ currL1(i) += math.abs(value)
+
+ nnz(i) += 1.0
+ }
+ }
+
+ /**
* Add a new sample to this summarizer, and update the statistical summary.
*
* @param sample The sample in dense/sparse vector format to be added into this summarizer.
@@ -72,37 +95,18 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
require(n == sample.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $n but got ${sample.size}.")
- @inline def update(i: Int, value: Double) = {
- if (value != 0.0) {
- if (currMax(i) < value) {
- currMax(i) = value
- }
- if (currMin(i) > value) {
- currMin(i) = value
- }
-
- val tmpPrevMean = currMean(i)
- currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
- currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)
- currM2(i) += value * value
- currL1(i) += math.abs(value)
-
- nnz(i) += 1.0
- }
- }
-
sample match {
case dv: DenseVector => {
var j = 0
while (j < dv.size) {
- update(j, dv.values(j))
+ add(j, dv.values(j))
j += 1
}
}
case sv: SparseVector =>
var j = 0
while (j < sv.indices.size) {
- update(sv.indices(j), sv.values(j))
+ add(sv.indices(j), sv.values(j))
j += 1
}
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
@@ -124,37 +128,28 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " +
s"Expecting $n but got ${other.n}.")
totalCnt += other.totalCnt
- val deltaMean: BDV[Double] = currMean - other.currMean
var i = 0
while (i < n) {
- // merge mean together
- if (other.currMean(i) != 0.0) {
- currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
- (nnz(i) + other.nnz(i))
- }
- // merge m2n together
- if (nnz(i) + other.nnz(i) != 0.0) {
- currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
- (nnz(i) + other.nnz(i))
- }
- // merge m2 together
- if (nnz(i) + other.nnz(i) != 0.0) {
+ val thisNnz = nnz(i)
+ val otherNnz = other.nnz(i)
+ val totalNnz = thisNnz + otherNnz
+ if (totalNnz != 0.0) {
+ val deltaMean = other.currMean(i) - currMean(i)
+ // merge mean together
+ currMean(i) += deltaMean * otherNnz / totalNnz
+ // merge m2n together
+ currM2n(i) += other.currM2n(i) + deltaMean * deltaMean * thisNnz * otherNnz / totalNnz
+ // merge m2 together
currM2(i) += other.currM2(i)
- }
- // merge l1 together
- if (nnz(i) + other.nnz(i) != 0.0) {
+ // merge l1 together
currL1(i) += other.currL1(i)
+ // merge max and min
+ currMax(i) = math.max(currMax(i), other.currMax(i))
+ currMin(i) = math.min(currMin(i), other.currMin(i))
}
-
- if (currMax(i) < other.currMax(i)) {
- currMax(i) = other.currMax(i)
- }
- if (currMin(i) > other.currMin(i)) {
- currMin(i) = other.currMin(i)
- }
+ nnz(i) = totalNnz
i += 1
}
- nnz += other.nnz
} else if (totalCnt == 0 && other.totalCnt != 0) {
this.n = other.n
this.currMean = other.currMean.copy
http://git-wip-us.apache.org/repos/asf/spark/blob/84324fbc/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index 1e94152..23b0eec 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -208,4 +208,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite {
assert(summarizer2.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch")
}
+
+ test("merging summarizer when one side has zero mean (SPARK-4355)") {
+ val s0 = new MultivariateOnlineSummarizer()
+ .add(Vectors.dense(2.0))
+ .add(Vectors.dense(2.0))
+ val s1 = new MultivariateOnlineSummarizer()
+ .add(Vectors.dense(1.0))
+ .add(Vectors.dense(-1.0))
+ s0.merge(s1)
+ assert(s0.mean(0) ~== 1.0 absTol 1e-14)
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org