You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2010/09/11 02:30:32 UTC

svn commit: r996022 - in /mahout/trunk/core/src: main/java/org/apache/mahout/classifier/evaluation/Auc.java test/java/org/apache/mahout/classifier/evaluation/ test/java/org/apache/mahout/classifier/evaluation/AucTest.java

Author: tdunning
Date: Sat Sep 11 00:30:32 2010
New Revision: 996022

URL: http://svn.apache.org/viewvc?rev=996022&view=rev
Log:
Added test cases for Auc.  Corrected a bug in the entropy matrix computation.

Added:
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/evaluation/
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/evaluation/AucTest.java
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java?rev=996022&r1=996021&r2=996022&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java Sat Sep 11 00:30:32 2010
@@ -83,10 +83,15 @@ public class Auc {
 
     int predictedClass = (score > threshold) ? 1 : 0;
     confusion.set(trueValue, predictedClass, confusion.get(trueValue, predictedClass) + 1);
+
+    samples++;
     if (isProbabilityScore()) {
       double limited = Math.max(1.0e-20, Math.min(score, 1 - 1.0e-20));
-      entropy.set(trueValue, 0, Math.log(1 - limited));
-      entropy.set(trueValue, 1, Math.log(limited));
+      double v0 = entropy.get(trueValue, 0);
+      entropy.set(trueValue, 0, (Math.log(1 - limited) - v0) / samples + v0);
+
+      double v1 = entropy.get(trueValue, 1);
+      entropy.set(trueValue, 1, (Math.log(limited) - v1) / samples + v1);
     }
 
     // add to buffers
@@ -95,7 +100,6 @@ public class Auc {
       // but if too many points are seen, we insert into a random
       // place and discard the predecessor.  The random place could
       // be anywhere, possibly not even in the buffer.
-      samples++;
       // this is a special case of Knuth's permutation algorithm
       // but since we don't ever shuffle the first maxBufferSize
       // samples, the result isn't just a fair sample of the prefixes

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/evaluation/AucTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/evaluation/AucTest.java?rev=996022&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/evaluation/AucTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/evaluation/AucTest.java Sat Sep 11 00:30:32 2010
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.evaluation;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.jet.random.Normal;
+import org.junit.Test;
+
+import java.util.Random;
+
+public class AucTest extends MahoutTestCase{
+  @Test
+  public void testAuc() {
+    Auc auc = new Auc();
+    Random gen = RandomUtils.getRandom();
+    auc.setProbabilityScore(false);
+    for (int i=0;i<100000;i++) {
+      auc.add(0, gen.nextGaussian());
+      auc.add(1, gen.nextGaussian() + 1);
+    }
+    assertEquals(0.76, auc.auc(), 0.01);
+  }
+
+  @Test
+  public void testEntropy() {
+    Auc auc = new Auc();
+    Random gen = RandomUtils.getRandom();
+    Normal n0 = new Normal(-1, 1, gen);
+    Normal n1 = new Normal(1, 1, gen);
+    for (int i=0;i<100000;i++) {
+      double score = n0.nextDouble();
+      double p = n1.pdf(score) / (n0.pdf(score) + n1.pdf(score));
+      auc.add(0, p);
+
+      score = n1.nextDouble();
+      p = n1.pdf(score) / (n0.pdf(score) + n1.pdf(score));
+      auc.add(1, p);
+    }
+    Matrix m = auc.entropy();
+    assertEquals(-0.35, m.get(0, 0), 0.02);
+    assertEquals(-2.34, m.get(0, 1), 0.02);
+    assertEquals(-2.34, m.get(1, 0), 0.02);
+    assertEquals(-0.35, m.get(1, 1), 0.02);
+  }
+}