You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2010/10/31 21:36:48 UTC

svn commit: r1029489 - in /mahout/trunk/core/src: main/java/org/apache/mahout/cf/taste/impl/common/ test/java/org/apache/mahout/cf/taste/impl/common/

Author: srowen
Date: Sun Oct 31 20:36:48 2010
New Revision: 1029489

URL: http://svn.apache.org/viewvc?rev=1029489&view=rev
Log:
Use Welford's method for standard deviation, for more accuracy

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/CompactRunningAverageAndStdDev.java
    mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java
    mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDevTest.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/CompactRunningAverageAndStdDev.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/CompactRunningAverageAndStdDev.java?rev=1029489&r1=1029488&r2=1029489&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/CompactRunningAverageAndStdDev.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/CompactRunningAverageAndStdDev.java Sun Oct 31 20:36:48 2010
@@ -20,37 +20,46 @@ package org.apache.mahout.cf.taste.impl.
 /**
  * <p>
  * Extends {@link CompactRunningAverage} to add a running standard deviation computation.
+ * Uses Welford's method, as described at http://www.johndcook.com/standard_deviation.html
  * </p>
  */
-public final class CompactRunningAverageAndStdDev extends CompactRunningAverage implements
-    RunningAverageAndStdDev {
-  
-  private float stdDev;
-  private float sumX2;
-  
-  public CompactRunningAverageAndStdDev() {
-    stdDev = Float.NaN;
-  }
-  
+public final class CompactRunningAverageAndStdDev extends CompactRunningAverage implements RunningAverageAndStdDev {
+
+  private float stdDev = Float.NaN;
+  private float mk;
+  private float sk;
+
   @Override
   public synchronized double getStandardDeviation() {
     return stdDev;
   }
-  
+
   @Override
   public synchronized void addDatum(double datum) {
     super.addDatum(datum);
-    sumX2 += (float) (datum * datum);
+    int count = getCount();
+    if (count == 1) {
+      mk = (float) datum;
+      sk = 0.0f;
+    } else {
+      float oldmk = mk;
+      float diff = (float) datum - oldmk;
+      mk += diff / count;
+      sk += diff * (datum - mk);
+    }
     recomputeStdDev();
   }
-  
+
   @Override
   public synchronized void removeDatum(double datum) {
+    int oldCount = getCount();
     super.removeDatum(datum);
-    sumX2 -= (float) (datum * datum);
+    float oldmk = mk;
+    mk = (oldCount * oldmk - (float) datum) / (oldCount - 1);
+    sk -= (datum - mk) * (datum - oldmk);
     recomputeStdDev();
   }
-  
+
   /**
    * @throws UnsupportedOperationException
    */
@@ -58,20 +67,19 @@ public final class CompactRunningAverage
   public void changeDatum(double delta) {
     throw new UnsupportedOperationException();
   }
-  
+
   private synchronized void recomputeStdDev() {
     int count = getCount();
     if (count > 1) {
-      double average = getAverage();
-      stdDev = (float) Math.sqrt((sumX2 - average * average * count) / (count - 1));
+      stdDev = (float) Math.sqrt(sk / (count - 1));
     } else {
       stdDev = Float.NaN;
     }
   }
-  
+
   @Override
   public synchronized String toString() {
     return String.valueOf(String.valueOf(getAverage()) + ',' + stdDev);
   }
-  
+
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java?rev=1029489&r1=1029488&r2=1029489&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java Sun Oct 31 20:36:48 2010
@@ -20,16 +20,14 @@ package org.apache.mahout.cf.taste.impl.
 /**
  * <p>
  * Extends {@link FullRunningAverage} to add a running standard deviation computation.
+ * Uses Welford's method, as described at http://www.johndcook.com/standard_deviation.html
  * </p>
  */
 public final class FullRunningAverageAndStdDev extends FullRunningAverage implements RunningAverageAndStdDev {
-  
-  private double stdDev;
-  private double sumX2;
-  
-  public FullRunningAverageAndStdDev() {
-    stdDev = Double.NaN;
-  }
+
+  private double stdDev = Double.NaN;
+  private double mk;
+  private double sk;
   
   @Override
   public synchronized double getStandardDeviation() {
@@ -39,14 +37,26 @@ public final class FullRunningAverageAnd
   @Override
   public synchronized void addDatum(double datum) {
     super.addDatum(datum);
-    sumX2 += datum * datum;
+    int count = getCount();
+    if (count == 1) {
+      mk = datum;
+      sk = 0.0;
+    } else {
+      double oldmk = mk;
+      double diff = datum - oldmk;
+      mk += diff / count;
+      sk += diff * (datum - mk);
+    }
     recomputeStdDev();
   }
   
   @Override
   public synchronized void removeDatum(double datum) {
+    int oldCount = getCount();
     super.removeDatum(datum);
-    sumX2 -= datum * datum;
+    double oldmk = mk;
+    mk = (oldCount * oldmk - datum) / (oldCount - 1);
+    sk -= (datum - mk) * (datum - oldmk);
     recomputeStdDev();
   }
   
@@ -61,8 +71,7 @@ public final class FullRunningAverageAnd
   private synchronized void recomputeStdDev() {
     int count = getCount();
     if (count > 1) {
-      double average = getAverage();
-      stdDev = Math.sqrt((sumX2 - average * average * count) / (count - 1));
+      stdDev = Math.sqrt(sk / (count - 1));
     } else {
       stdDev = Double.NaN;
     }

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDevTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDevTest.java?rev=1029489&r1=1029488&r2=1029489&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDevTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDevTest.java Sun Oct 31 20:36:48 2010
@@ -18,10 +18,16 @@
 package org.apache.mahout.cf.taste.impl.common;
 
 import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.common.RandomUtils;
 import org.junit.Test;
 
+import java.util.Random;
+
 public final class RunningAverageAndStdDevTest extends TasteTestCase {
 
+  private static final double SMALL_EPSILON = 1.0;
+  private static final double BIG_EPSILON = 100 * SMALL_EPSILON;
+
   @Test
   public void testFull() {
     doTestAverageAndStdDev(new FullRunningAverageAndStdDev());
@@ -32,6 +38,16 @@ public final class RunningAverageAndStdD
     doTestAverageAndStdDev(new CompactRunningAverageAndStdDev());
   }
 
+  @Test
+  public void testFullBig() {
+    doTestBig(new FullRunningAverageAndStdDev(), SMALL_EPSILON);
+  }
+
+  @Test
+  public void testCompactBig() {
+    doTestBig(new CompactRunningAverageAndStdDev(), BIG_EPSILON);
+  }
+
   private static void doTestAverageAndStdDev(RunningAverageAndStdDev average) {
 
     assertEquals(0, average.getCount());
@@ -65,4 +81,15 @@ public final class RunningAverageAndStdD
 
   }
 
+  private static void doTestBig(RunningAverageAndStdDev average, double epsilon) {
+
+    Random r = RandomUtils.getRandom();
+    for (int i = 0; i < 100000; i++) {
+      average.addDatum(r.nextDouble() * 1000.0);
+    }
+    assertEquals(500.0, average.getAverage(), epsilon);
+    assertEquals(1000.0 / Math.sqrt(12.0), average.getStandardDeviation(), epsilon);
+
+  }
+
 }