You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@mahout.apache.org by td...@apache.org on 2010/09/22 06:22:54 UTC

svn commit: r999749 - in /mahout/trunk/core/src: main/java/org/apache/mahout/math/stats/OnlineAuc.java test/java/org/apache/mahout/math/stats/OnlineAucTest.java

Author: tdunning
Date: Wed Sep 22 04:22:54 2010
New Revision: 999749

URL: http://svn.apache.org/viewvc?rev=999749&view=rev
Log:
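Tuned the OnlineAuc history window (HISTORY) and default replacement policy for accuracy; also skip (rather than stop at) NaN history entries, score ties as 0.5, and remove setRandom.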

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
    mahout/trunk/core/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java?rev=999749&r1=999748&r2=999749&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java Wed Sep 22 04:22:54 2010
@@ -38,9 +38,13 @@ public class OnlineAuc {
     FIFO, FAIR, RANDOM
   }
 
-  public static final int HISTORY = 100;
+  // Size of the score history window.  Raising it to 100 gives only very small accuracy gains;
+  // lowering it to 2 substantially degrades the FAIR and RANDOM policies but barely affects
+  // the FIFO policy.
+  public static final int HISTORY = 10;
 
-  private ReplacementPolicy policy = ReplacementPolicy.FAIR;
+  // FIFO was the most accurate policy in OnlineAucTest (RANDOM is close, FAIR is clearly worse).
+  private ReplacementPolicy policy = ReplacementPolicy.FIFO;
   private transient Random random = org.apache.mahout.common.RandomUtils.getRandom();
   private final Matrix scores;
   private final Vector averages;
@@ -83,36 +87,32 @@ public class OnlineAuc {
       // compare to previous scores for other category
       Vector row = scores.viewRow(1 - category);
       double m = 0.0;
-      int count = 0;
+      double count = 0.0;
       for (Vector.Element element : row) {
         double v = element.get();
         if (Double.isNaN(v)) {
-          break;
+          continue;
         }
         count++;
-        double z = 0;
         if (score > v) {
-          z = 1.0;
+          m++;
         } else if (score < v) {
-          z = 0.0;
+          // score below v: contributes nothing to m; ties (below) count as half a win
+        } else if (score == v) {
+          m += 0.5;
         }
-        m += (z - m) / count;
       }
-      averages.set(category, averages.get(category) + (m - averages.get(category)) / samples.get(category));
+      averages.set(category, averages.get(category) + (m / count - averages.get(category)) / samples.get(category));
     }
     return auc();
   }
 
   public double auc() {
     // return an unweighted average of all averages.
-    return 0.5 - averages.get(0) / 2 + averages.get(1) / 2;
+    return (1 - averages.get(0) + averages.get(1)) / 2;
   }
 
   public void setPolicy(ReplacementPolicy policy) {
     this.policy = policy;
   }
-
-  public void setRandom(Random random) {
-    this.random = random;
-  }
 }

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java?rev=999749&r1=999748&r2=999749&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java Wed Sep 22 04:22:54 2010
@@ -17,45 +17,81 @@
 
 package org.apache.mahout.math.stats;
 
+import org.apache.mahout.classifier.evaluation.Auc;
 import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
 import org.junit.Test;
 
 import java.util.Random;
 
+import static org.apache.mahout.math.stats.OnlineAuc.ReplacementPolicy.*;
+
 public final class OnlineAucTest extends MahoutTestCase {
 
   @Test
   public void testBinaryCase() {
-    OnlineAuc a1 = new OnlineAuc();
-    a1.setRandom(new Random(1));
-    a1.setPolicy(OnlineAuc.ReplacementPolicy.FAIR);
-
-    OnlineAuc a2 = new OnlineAuc();
-    a2.setRandom(new Random(2));
-    a2.setPolicy(OnlineAuc.ReplacementPolicy.FIFO);
-
-    OnlineAuc a3 = new OnlineAuc();
-    a3.setRandom(new Random(3));
-    a3.setPolicy(OnlineAuc.ReplacementPolicy.RANDOM);
-
-    Random gen = new Random(1);
-    for (int i = 0; i < 10000; i++) {
-      double x = gen.nextGaussian();
-
-      a1.addSample(1, x);
-      a2.addSample(1, x);
-      a3.addSample(1, x);
-
-      x = gen.nextGaussian() + 1;
-
-      a1.addSample(0, x);
-      a2.addSample(0, x);
-      a3.addSample(0, x);
+    Random gen = RandomUtils.getRandom();
+
+    OnlineSummarizer[] stats = new OnlineSummarizer[4];
+    for (int i = 0; i < 4; i++) {
+      stats[i] = new OnlineSummarizer();
+    }
+
+    for (int i = 0; i < 500; i++) {
+      OnlineAuc a1 = new OnlineAuc();
+      a1.setPolicy(FAIR);
+
+      OnlineAuc a2 = new OnlineAuc();
+      a2.setPolicy(FIFO);
+
+      OnlineAuc a3 = new OnlineAuc();
+      a3.setPolicy(RANDOM);
+
+      Auc a4 = new Auc();
+
+      for (int j = 0; j < 10000; j++) {
+        double x = gen.nextGaussian();
+
+        a1.addSample(0, x);
+        a2.addSample(0, x);
+        a3.addSample(0, x);
+        a4.add(0, x);
+
+        x = gen.nextGaussian() + 1;
+
+        a1.addSample(1, x);
+        a2.addSample(1, x);
+        a3.addSample(1, x);
+        a4.add(1, x);
+      }
+
+      stats[0].add(a1.auc());
+      stats[1].add(a2.auc());
+      stats[2].add(a3.auc());
+      stats[3].add(a4.auc());
+    }
+    
+    int i = 0;
+    for (OnlineAuc.ReplacementPolicy policy : new OnlineAuc.ReplacementPolicy[]{FAIR, FIFO, RANDOM, null}) {
+      OnlineSummarizer summary = stats[i++];
+      System.out.printf("%s,%.4f (min=%.4f, 25%%-ile=%.4f, median=%.4f, 75%%-ile=%.4f, max=%.4f)%n", policy,
+        summary.getMean(), summary.getQuartile(0), summary.getQuartile(1), summary.getQuartile(2), summary.getQuartile(3), summary.getQuartile(4));
+
     }
 
-    // reference value computed using R: mean(rnorm(1000000) < rnorm(1000000,1))
-    assertEquals(1 - 0.76, a1.auc(), 0.05);
-    assertEquals(1 - 0.76, a2.auc(), 0.05);
-    assertEquals(1 - 0.76, a3.auc(), 0.05);
+    // FAIR policy is noticeably less accurate (looser tolerances)
+    assertEquals(0.7603, stats[0].getMean(), 0.03);
+    assertEquals(0.7603, stats[0].getQuartile(1), 0.025);
+    assertEquals(0.7603, stats[0].getQuartile(3), 0.025);
+
+    // FIFO policy seems best
+    assertEquals(0.7603, stats[1].getMean(), 0.001);
+    assertEquals(0.7603, stats[1].getQuartile(1), 0.006);
+    assertEquals(0.7603, stats[1].getQuartile(3), 0.006);
+
+    // RANDOM policy is nearly the same as FIFO; check both the 25th and 75th percentiles
+    assertEquals(0.7603, stats[2].getMean(), 0.001);
+    assertEquals(0.7603, stats[2].getQuartile(1), 0.006);
+    assertEquals(0.7603, stats[2].getQuartile(3), 0.006);
   }
 }