You are viewing a plain-text version of this content. The canonical link to it (a hyperlink in the original page) is not preserved in this text.
Posted to commits@mahout.apache.org by ro...@apache.org on 2013/06/01 22:44:20 UTC

svn commit: r1488595 - in /mahout/trunk/core/src: main/java/org/apache/mahout/classifier/ConfusionMatrix.java main/java/org/apache/mahout/classifier/ResultAnalyzer.java test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java

Author: robinanil
Date: Sat Jun  1 20:44:20 2013
New Revision: 1488595

URL: http://svn.apache.org/r1488595
Log:
MAHOUT-941 new confusion matrix statistics

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java?rev=1488595&r1=1488594&r2=1488595&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java Sat Jun  1 20:44:20 2013
@@ -22,8 +22,11 @@ import java.util.Collections;
 import java.util.Map;
 
 import org.apache.commons.lang3.StringUtils;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
 import org.apache.mahout.math.DenseMatrix;
 import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.stats.OnlineSummarizer;
 
 import com.google.common.base.Preconditions;
 import com.google.common.collect.Maps;
@@ -38,6 +41,7 @@ import com.google.common.collect.Maps;
 public class ConfusionMatrix {
   private final Map<String,Integer> labelMap = Maps.newLinkedHashMap();
   private final int[][] confusionMatrix;
+  private int samples = 0;
   private String defaultLabel = "unknown";
   
   public ConfusionMatrix(Collection<String> labels, String defaultLabel) {
@@ -70,12 +74,87 @@ public class ConfusionMatrix {
     for (int i = 0; i < labelMap.size(); i++) {
       labelTotal += confusionMatrix[labelId][i];
       if (i == labelId) {
-        correct = confusionMatrix[labelId][i];
+        correct += confusionMatrix[labelId][i];
       }
     }
     return 100.0 * correct / labelTotal;
   }
+
+  // Producer accuracy
+  public double getAccuracy() {
+    int total = 0;
+    int correct = 0;
+    for (int i = 0; i < labelMap.size(); i++) {
+      for (int j = 0; j < labelMap.size(); j++) {
+        total += confusionMatrix[i][j];
+        if (i == j) {
+          correct += confusionMatrix[i][j];
+        }
+      }
+    }
+    return 100.0 * correct / total;
+  }
   
+  // User accuracy 
+  public double getReliability() {
+    int count = 0;
+    double accuracy = 0;
+    for (String label: labelMap.keySet()) {
+      if (! label.equals(defaultLabel)) {
+        accuracy += getAccuracy(label);
+      }
+      count++;
+    }
+    return accuracy / count;
+  }
+  
+  /**
+   * Accuracy v.s. randomly classifying all samples.
+   * kappa() = (totalAccuracy() - randomAccuracy()) / (1 - randomAccuracy())
+   * Cohen, Jacob. 1960. A coefficient of agreement for nominal scales. 
+   * Educational And Psychological Measurement 20:37-46.
+   * 
+   * Formula and variable names from:
+   * http://www.yale.edu/ceo/OEFS/Accuracy.pdf
+   * 
+   * @return double
+   */
+  public double getKappa() {
+    double a = 0.0;
+    double b = 0.0;
+    for(int i = 0; i < confusionMatrix.length; i++) {
+      a += confusionMatrix[i][i];
+      double br = 0;
+      for(int j = 0; j < confusionMatrix.length; j++) {
+        br += confusionMatrix[i][j];
+      }
+      double bc = 0;
+      for(int j = 0; j < confusionMatrix.length; j++) {
+        bc += confusionMatrix[j][i];
+      }
+      b += br * bc;
+    }
+    return (samples * a - b) / (samples * samples - b);
+  }
+  
+  /**
+   * Standard deviation of normalized producer accuracy
+   * Not a standard score
+   * @return double
+   */
+  public RunningAverageAndStdDev getNormalizedStats() {
+    RunningAverageAndStdDev summer = new FullRunningAverageAndStdDev();
+    for(int d = 0; d < confusionMatrix.length; d++) {
+      double total = 0;
+      for(int j = 0; j < confusionMatrix.length; j++) {
+        total += confusionMatrix[d][j];
+      }
+      summer.addDatum(confusionMatrix[d][d] / (total + 0.000001));
+    }
+    
+    return summer;
+  }
+   
   public int getCorrect(String label) {
     int labelId = labelMap.get(label);
     return confusionMatrix[labelId][labelId];
@@ -91,10 +170,12 @@ public class ConfusionMatrix {
   }
   
   public void addInstance(String correctLabel, ClassifierResult classifiedResult) {
+    samples++;
     incrementCount(correctLabel, classifiedResult.getLabel());
   }
   
   public void addInstance(String correctLabel, String classifiedLabel) {
+    samples++;
     incrementCount(correctLabel, classifiedLabel);
   }
   
@@ -111,6 +192,9 @@ public class ConfusionMatrix {
     Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel);
     int correctId = labelMap.get(correctLabel);
     int classifiedId = labelMap.get(classifiedLabel);
+    if (confusionMatrix[correctId][classifiedId] == 0.0 && count != 0) {
+      samples++;
+    }
     confusionMatrix[correctId][classifiedId] = count;
   }
   

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java?rev=1488595&r1=1488594&r2=1488595&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java Sat Jun  1 20:44:20 2013
@@ -22,6 +22,7 @@ import java.text.NumberFormat;
 import java.util.Collection;
 
 import org.apache.commons.lang3.StringUtils;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
 import org.apache.mahout.math.stats.OnlineSummarizer;
 
 /**
@@ -77,7 +78,8 @@ public class ResultAnalyzer {
   @Override
   public String toString() {
     StringBuilder returnString = new StringBuilder();
-    
+   
+    returnString.append("\n"); 
     returnString.append("=======================================================\n");
     returnString.append("Summary\n");
     returnString.append("-------------------------------------------------------\n");
@@ -97,10 +99,27 @@ public class ResultAnalyzer {
     returnString.append('\n');
     
     returnString.append(confusionMatrix);
+    returnString.append("=======================================================\n");
+    returnString.append("Statistics\n");
+    returnString.append("-------------------------------------------------------\n");
+    
+    RunningAverageAndStdDev normStats = confusionMatrix.getNormalizedStats();
+    returnString.append(StringUtils.rightPad("Kappa", 40)).append(
+      StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getKappa()), 10)).append('\n');
+    returnString.append(StringUtils.rightPad("Accuracy", 40)).append(
+      StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getAccuracy()), 10)).append("%\n");
+    returnString.append(StringUtils.rightPad("Reliability", 40)).append(
+      StringUtils.leftPad(decimalFormatter.format(normStats.getAverage() * 100.00000001), 10)).append("%\n");
+    returnString.append(StringUtils.rightPad("Reliability (standard deviation)", 40)).append(
+      StringUtils.leftPad(decimalFormatter.format(normStats.getStandardDeviation()), 10)).append("\n"); 
+    
     if (hasLL) {
-      returnString.append("\n\n");
-      returnString.append("Avg. Log-likelihood: ").append(summarizer.getMean()).append(" 25%-ile: ")
-          .append(summarizer.getQuartile(1)).append(" 75%-ile: ").append(summarizer.getQuartile(2));
+      returnString.append(StringUtils.rightPad("Log-likelihood", 30)).append("mean      : ").append(
+        StringUtils.leftPad(decimalFormatter.format(summarizer.getMean()), 10)).append('\n');
+      returnString.append(StringUtils.rightPad("", 30)).append(StringUtils.rightPad("25%-ile   : ", 10)).append(
+        StringUtils.leftPad(decimalFormatter.format(summarizer.getQuartile(1)), 10)).append('\n');
+      returnString.append(StringUtils.rightPad("", 30)).append(StringUtils.rightPad("75%-ile   : ", 10)).append(
+        StringUtils.leftPad(decimalFormatter.format(summarizer.getQuartile(3)), 10)).append('\n');
     }
 
     return returnString.toString();

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java?rev=1488595&r1=1488594&r2=1488595&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java Sat Jun  1 20:44:20 2013
@@ -30,6 +30,7 @@ public final class ConfusionMatrixTest e
 
   private static final int[][] VALUES = {{2, 3}, {10, 20}};
   private static final String[] LABELS = {"Label1", "Label2"};
+  private static final int[] OTHER = {3, 6};
   private static final String DEFAULT_LABEL = "other";
   
   @Test
@@ -63,8 +64,8 @@ public final class ConfusionMatrixTest e
     assertEquals(VALUES[1][0], counts[1][0]);
     assertEquals(VALUES[1][1], counts[1][1]);
     assertTrue(Arrays.equals(new int[3], counts[2])); // zeros
-    assertEquals(0, counts[0][2]);
-    assertEquals(0, counts[1][2]);
+    assertEquals(OTHER[0], counts[0][2]);
+    assertEquals(OTHER[1], counts[1][2]);
     assertEquals(3, cm.getLabels().size());
     assertTrue(cm.getLabels().contains(LABELS[0]));
     assertTrue(cm.getLabels().contains(LABELS[1]));
@@ -75,8 +76,8 @@ public final class ConfusionMatrixTest e
   private static void checkAccuracy(ConfusionMatrix cm) {
     Collection<String> labelstrs = cm.getLabels();
     assertEquals(3, labelstrs.size());
-    assertEquals(40.0, cm.getAccuracy("Label1"), EPSILON);
-    assertEquals(66.666666667, cm.getAccuracy("Label2"), EPSILON);
+    assertEquals(25.0, cm.getAccuracy("Label1"), EPSILON);
+    assertEquals(55.5555555, cm.getAccuracy("Label2"), EPSILON);
     assertTrue(Double.isNaN(cm.getAccuracy("other")));
   }
   
@@ -86,10 +87,12 @@ public final class ConfusionMatrixTest e
     labelList.add(labels[1]);
     ConfusionMatrix cm = new ConfusionMatrix(labelList, defaultLabel);
     int[][] v = cm.getConfusionMatrix();
-    v[0][0] = values[0][0];
-    v[0][1] = values[0][1];
-    v[1][0] = values[1][0];
-    v[1][1] = values[1][1];
+    cm.putCount("Label1", "Label1", values[0][0]);
+    cm.putCount("Label1", "Label2", values[0][1]);
+    cm.putCount("Label2", "Label1", values[1][0]);
+    cm.putCount("Label2", "Label2", values[1][1]);
+    cm.putCount("Label1", DEFAULT_LABEL, OTHER[0]);
+    cm.putCount("Label2", DEFAULT_LABEL, OTHER[1]);
     return cm;
   }