You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2010/09/22 06:22:54 UTC
svn commit: r999749 - in /mahout/trunk/core/src:
main/java/org/apache/mahout/math/stats/OnlineAuc.java
test/java/org/apache/mahout/math/stats/OnlineAucTest.java
Author: tdunning
Date: Wed Sep 22 04:22:54 2010
New Revision: 999749
URL: http://svn.apache.org/viewvc?rev=999749&view=rev
Log:
Tuned OnlineAuc window and policy for accuracy.
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
mahout/trunk/core/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java?rev=999749&r1=999748&r2=999749&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java Wed Sep 22 04:22:54 2010
@@ -38,9 +38,13 @@ public class OnlineAuc {
FIFO, FAIR, RANDOM
}
- public static final int HISTORY = 100;
+ // increasing this to 100 causes very small improvements in accuracy. Decreasing it to 2
+ // causes substantial degradation for the FAIR and RANDOM policies, but almost no change
+ // for the FIFO policy
+ public static final int HISTORY = 10;
- private ReplacementPolicy policy = ReplacementPolicy.FAIR;
+ // FIFO has distinctly the best properties as a policy. See OnlineAucTest for details
+ private ReplacementPolicy policy = ReplacementPolicy.FIFO;
private transient Random random = org.apache.mahout.common.RandomUtils.getRandom();
private final Matrix scores;
private final Vector averages;
@@ -83,36 +87,32 @@ public class OnlineAuc {
// compare to previous scores for other category
Vector row = scores.viewRow(1 - category);
double m = 0.0;
- int count = 0;
+ double count = 0.0;
for (Vector.Element element : row) {
double v = element.get();
if (Double.isNaN(v)) {
- break;
+ continue;
}
count++;
- double z = 0;
if (score > v) {
- z = 1.0;
+ m++;
} else if (score < v) {
- z = 0.0;
+ // m += 0
+ } else if (score == v) {
+ m += 0.5;
}
- m += (z - m) / count;
}
- averages.set(category, averages.get(category) + (m - averages.get(category)) / samples.get(category));
+ averages.set(category, averages.get(category) + (m / count - averages.get(category)) / samples.get(category));
}
return auc();
}
public double auc() {
// return an unweighted average of all averages.
- return 0.5 - averages.get(0) / 2 + averages.get(1) / 2;
+ return (1 - averages.get(0) + averages.get(1)) / 2;
}
public void setPolicy(ReplacementPolicy policy) {
this.policy = policy;
}
-
- public void setRandom(Random random) {
- this.random = random;
- }
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java?rev=999749&r1=999748&r2=999749&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java Wed Sep 22 04:22:54 2010
@@ -17,45 +17,81 @@
package org.apache.mahout.math.stats;
+import org.apache.mahout.classifier.evaluation.Auc;
import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
import org.junit.Test;
import java.util.Random;
+import static org.apache.mahout.math.stats.OnlineAuc.ReplacementPolicy.*;
+
public final class OnlineAucTest extends MahoutTestCase {
@Test
public void testBinaryCase() {
- OnlineAuc a1 = new OnlineAuc();
- a1.setRandom(new Random(1));
- a1.setPolicy(OnlineAuc.ReplacementPolicy.FAIR);
-
- OnlineAuc a2 = new OnlineAuc();
- a2.setRandom(new Random(2));
- a2.setPolicy(OnlineAuc.ReplacementPolicy.FIFO);
-
- OnlineAuc a3 = new OnlineAuc();
- a3.setRandom(new Random(3));
- a3.setPolicy(OnlineAuc.ReplacementPolicy.RANDOM);
-
- Random gen = new Random(1);
- for (int i = 0; i < 10000; i++) {
- double x = gen.nextGaussian();
-
- a1.addSample(1, x);
- a2.addSample(1, x);
- a3.addSample(1, x);
-
- x = gen.nextGaussian() + 1;
-
- a1.addSample(0, x);
- a2.addSample(0, x);
- a3.addSample(0, x);
+ Random gen = RandomUtils.getRandom();
+
+ OnlineSummarizer[] stats = new OnlineSummarizer[4];
+ for (int i = 0; i < 4; i++) {
+ stats[i] = new OnlineSummarizer();
+ }
+
+ for (int i = 0; i < 500; i++) {
+ OnlineAuc a1 = new OnlineAuc();
+ a1.setPolicy(FAIR);
+
+ OnlineAuc a2 = new OnlineAuc();
+ a2.setPolicy(FIFO);
+
+ OnlineAuc a3 = new OnlineAuc();
+ a3.setPolicy(RANDOM);
+
+ Auc a4 = new Auc();
+
+ for (int j = 0; j < 10000; j++) {
+ double x = gen.nextGaussian();
+
+ a1.addSample(0, x);
+ a2.addSample(0, x);
+ a3.addSample(0, x);
+ a4.add(0, x);
+
+ x = gen.nextGaussian() + 1;
+
+ a1.addSample(1, x);
+ a2.addSample(1, x);
+ a3.addSample(1, x);
+ a4.add(1, x);
+ }
+
+ stats[0].add(a1.auc());
+ stats[1].add(a2.auc());
+ stats[2].add(a3.auc());
+ stats[3].add(a4.auc());
+ }
+
+ int i = 0;
+ for (OnlineAuc.ReplacementPolicy policy : new OnlineAuc.ReplacementPolicy[]{FAIR, FIFO, RANDOM, null}) {
+ OnlineSummarizer summary = stats[i++];
+ System.out.printf("%s,%.4f (min = %.4f, 25%%-ile=%.4f, 75%%-ile=%.4f, max=%.4f)\n", policy, summary.getMean(),
+ summary.getQuartile(0), summary.getQuartile(1), summary.getQuartile(2), summary.getQuartile(3));
+
}
- // reference value computed using R: mean(rnorm(1000000) < rnorm(1000000,1))
- assertEquals(1 - 0.76, a1.auc(), 0.05);
- assertEquals(1 - 0.76, a2.auc(), 0.05);
- assertEquals(1 - 0.76, a3.auc(), 0.05);
+ // FAIR policy isn't so accurate
+ assertEquals(0.7603, stats[0].getMean(), 0.03);
+ assertEquals(0.7603, stats[0].getQuartile(1), 0.025);
+ assertEquals(0.7603, stats[0].getQuartile(3), 0.025);
+
+ // FIFO policy seems best
+ assertEquals(0.7603, stats[1].getMean(), 0.001);
+ assertEquals(0.7603, stats[1].getQuartile(1), 0.006);
+ assertEquals(0.7603, stats[1].getQuartile(3), 0.006);
+
+ // RANDOM policy is nearly the same as FIFO
+ assertEquals(0.7603, stats[2].getMean(), 0.001);
+ assertEquals(0.7603, stats[2].getQuartile(1), 0.006);
+ assertEquals(0.7603, stats[2].getQuartile(3), 0.006);
}
}