You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ro...@apache.org on 2013/06/01 22:44:20 UTC
svn commit: r1488595 - in /mahout/trunk/core/src:
main/java/org/apache/mahout/classifier/ConfusionMatrix.java
main/java/org/apache/mahout/classifier/ResultAnalyzer.java
test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
Author: robinanil
Date: Sat Jun 1 20:44:20 2013
New Revision: 1488595
URL: http://svn.apache.org/r1488595
Log:
MAHOUT-941 new confusion matrix statistics
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java?rev=1488595&r1=1488594&r2=1488595&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java Sat Jun 1 20:44:20 2013
@@ -22,8 +22,11 @@ import java.util.Collections;
import java.util.Map;
import org.apache.commons.lang3.StringUtils;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.stats.OnlineSummarizer;
import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
@@ -38,6 +41,7 @@ import com.google.common.collect.Maps;
public class ConfusionMatrix {
private final Map<String,Integer> labelMap = Maps.newLinkedHashMap();
private final int[][] confusionMatrix;
+ private int samples = 0;
private String defaultLabel = "unknown";
public ConfusionMatrix(Collection<String> labels, String defaultLabel) {
@@ -70,12 +74,87 @@ public class ConfusionMatrix {
for (int i = 0; i < labelMap.size(); i++) {
labelTotal += confusionMatrix[labelId][i];
if (i == labelId) {
- correct = confusionMatrix[labelId][i];
+ correct += confusionMatrix[labelId][i];
}
}
return 100.0 * correct / labelTotal;
}
+
+ /**
+  * Overall accuracy of the matrix: the share of all recorded counts that sit on the
+  * diagonal, i.e. in cells whose row and column index are the same.
+  *
+  * @return 100 * (sum of diagonal cells) / (sum of all cells); NaN when every cell is zero
+  */
+ public double getAccuracy() {
+   final int size = labelMap.size();
+   int diagonalSum = 0;
+   int grandTotal = 0;
+   for (int row = 0; row < size; row++) {
+     // The diagonal cell of this row is the only one that counts as a hit.
+     diagonalSum += confusionMatrix[row][row];
+     for (int col = 0; col < size; col++) {
+       grandTotal += confusionMatrix[row][col];
+     }
+   }
+   return 100.0 * diagonalSum / grandTotal;
+ }
+ /**
+  * Mean of the per-label accuracies, as returned by {@code getAccuracy(label)}, taken over
+  * every label except the default label.
+  *
+  * The default label is excluded from the divisor as well as from the sum, so the result is a
+  * true mean of the values that were added up.
+  *
+  * @return the mean per-label accuracy in percent; 0 when there is no label besides the default
+  *         one; NaN when {@code getAccuracy(label)} is NaN for any averaged label
+  */
+ public double getReliability() {
+   int count = 0;
+   double accuracy = 0;
+   for (String label : labelMap.keySet()) {
+     if (!label.equals(defaultLabel)) {
+       accuracy += getAccuracy(label);
+       // Count only the labels that actually contributed a term to the sum.
+       count++;
+     }
+   }
+   // Nothing averaged: report 0 rather than 0/0 = NaN.
+   return count == 0 ? 0.0 : accuracy / count;
+ }
+
+ /**
+  * Cohen's kappa: accuracy compared with randomly classifying all samples.
+  * kappa() = (totalAccuracy() - randomAccuracy()) / (1 - randomAccuracy())
+  * Cohen, Jacob. 1960. A coefficient of agreement for nominal scales.
+  * Educational And Psychological Measurement 20:37-46.
+  *
+  * Computed in count form as
+  * (N * sum(x[i][i]) - sum(rowTotal[i] * colTotal[i])) / (N * N - sum(rowTotal[i] * colTotal[i])).
+  * Formula from:
+  * http://www.yale.edu/ceo/OEFS/Accuracy.pdf
+  *
+  * N is the sum of all cells of the matrix, taken from the matrix itself so that it always
+  * agrees with the stored counts however they were entered. All arithmetic is done in double,
+  * so N * N cannot overflow.
+  *
+  * @return kappa; NaN when the denominator is zero (for example an all-zero matrix, or all
+  *         counts in a single cell)
+  */
+ public double getKappa() {
+   double diagonalSum = 0.0;
+   double chanceSum = 0.0;
+   double n = 0.0;
+   for (int i = 0; i < confusionMatrix.length; i++) {
+     diagonalSum += confusionMatrix[i][i];
+     double rowTotal = 0;
+     double colTotal = 0;
+     for (int j = 0; j < confusionMatrix.length; j++) {
+       rowTotal += confusionMatrix[i][j];
+       colTotal += confusionMatrix[j][i];
+     }
+     chanceSum += rowTotal * colTotal;
+     // The row totals add up to the number of samples held in the matrix.
+     n += rowTotal;
+   }
+   return (n * diagonalSum - chanceSum) / (n * n - chanceSum);
+ }
+
+ /**
+  * Gathers, for each row of the matrix, the diagonal count divided by the row total, and
+  * feeds those ratios into a running average / standard deviation accumulator.
+  * Not a standard score.
+  *
+  * A small constant (0.000001) is added to each row total so that an empty row yields a
+  * ratio of 0 instead of a division by zero.
+  *
+  * @return accumulator holding one datum per matrix row
+  */
+ public RunningAverageAndStdDev getNormalizedStats() {
+   final int size = confusionMatrix.length;
+   RunningAverageAndStdDev stats = new FullRunningAverageAndStdDev();
+   for (int row = 0; row < size; row++) {
+     int[] counts = confusionMatrix[row];
+     double rowTotal = 0;
+     for (int col = 0; col < size; col++) {
+       rowTotal += counts[col];
+     }
+     double diagonalShare = counts[row] / (rowTotal + 0.000001);
+     stats.addDatum(diagonalShare);
+   }
+   return stats;
+ }
+
public int getCorrect(String label) {
int labelId = labelMap.get(label);
return confusionMatrix[labelId][labelId];
@@ -91,10 +170,12 @@ public class ConfusionMatrix {
}
public void addInstance(String correctLabel, ClassifierResult classifiedResult) { // records one instance; the assigned label is read from classifiedResult.getLabel()
+ samples++; // one more instance seen. NOTE(review): putCount also bumps samples when a cell goes from 0 to non-zero; if incrementCount delegates to putCount this may count twice - confirm
incrementCount(correctLabel, classifiedResult.getLabel()); // presumably adds 1 to the (correctLabel, assigned label) cell; incrementCount is not shown in this diff
}
public void addInstance(String correctLabel, String classifiedLabel) { // same as the overload above, with the assigned label passed directly
+ samples++; // one more instance seen (same review note as in the overload above applies)
incrementCount(correctLabel, classifiedLabel); // presumably adds 1 to the (correctLabel, classifiedLabel) cell; incrementCount is not shown in this diff
}
@@ -111,6 +192,9 @@ public class ConfusionMatrix {
Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel);
int correctId = labelMap.get(correctLabel);
int classifiedId = labelMap.get(classifiedLabel);
+ if (confusionMatrix[correctId][classifiedId] == 0.0 && count != 0) {
+ samples++;
+ }
confusionMatrix[correctId][classifiedId] = count;
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java?rev=1488595&r1=1488594&r2=1488595&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java Sat Jun 1 20:44:20 2013
@@ -22,6 +22,7 @@ import java.text.NumberFormat;
import java.util.Collection;
import org.apache.commons.lang3.StringUtils;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
import org.apache.mahout.math.stats.OnlineSummarizer;
/**
@@ -77,7 +78,8 @@ public class ResultAnalyzer {
@Override
public String toString() {
StringBuilder returnString = new StringBuilder();
-
+
+ returnString.append("\n");
returnString.append("=======================================================\n");
returnString.append("Summary\n");
returnString.append("-------------------------------------------------------\n");
@@ -97,10 +99,27 @@ public class ResultAnalyzer {
returnString.append('\n');
returnString.append(confusionMatrix);
+ returnString.append("=======================================================\n");
+ returnString.append("Statistics\n");
+ returnString.append("-------------------------------------------------------\n");
+
+ RunningAverageAndStdDev normStats = confusionMatrix.getNormalizedStats();
+ returnString.append(StringUtils.rightPad("Kappa", 40)).append(
+ StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getKappa()), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Accuracy", 40)).append(
+ StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getAccuracy()), 10)).append("%\n");
+ returnString.append(StringUtils.rightPad("Reliability", 40)).append(
+ StringUtils.leftPad(decimalFormatter.format(normStats.getAverage() * 100.00000001), 10)).append("%\n");
+ returnString.append(StringUtils.rightPad("Reliability (standard deviation)", 40)).append(
+ StringUtils.leftPad(decimalFormatter.format(normStats.getStandardDeviation()), 10)).append("\n");
+
if (hasLL) {
- returnString.append("\n\n");
- returnString.append("Avg. Log-likelihood: ").append(summarizer.getMean()).append(" 25%-ile: ")
- .append(summarizer.getQuartile(1)).append(" 75%-ile: ").append(summarizer.getQuartile(2));
+ returnString.append(StringUtils.rightPad("Log-likelihood", 30)).append("mean : ").append(
+ StringUtils.leftPad(decimalFormatter.format(summarizer.getMean()), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("", 30)).append(StringUtils.rightPad("25%-ile : ", 10)).append(
+ StringUtils.leftPad(decimalFormatter.format(summarizer.getQuartile(1)), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("", 30)).append(StringUtils.rightPad("75%-ile : ", 10)).append(
+ StringUtils.leftPad(decimalFormatter.format(summarizer.getQuartile(3)), 10)).append('\n');
}
return returnString.toString();
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java?rev=1488595&r1=1488594&r2=1488595&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java Sat Jun 1 20:44:20 2013
@@ -30,6 +30,7 @@ public final class ConfusionMatrixTest e
private static final int[][] VALUES = {{2, 3}, {10, 20}};
private static final String[] LABELS = {"Label1", "Label2"};
+ private static final int[] OTHER = {3, 6};
private static final String DEFAULT_LABEL = "other";
@Test
@@ -63,8 +64,8 @@ public final class ConfusionMatrixTest e
assertEquals(VALUES[1][0], counts[1][0]);
assertEquals(VALUES[1][1], counts[1][1]);
assertTrue(Arrays.equals(new int[3], counts[2])); // zeros
- assertEquals(0, counts[0][2]);
- assertEquals(0, counts[1][2]);
+ assertEquals(OTHER[0], counts[0][2]);
+ assertEquals(OTHER[1], counts[1][2]);
assertEquals(3, cm.getLabels().size());
assertTrue(cm.getLabels().contains(LABELS[0]));
assertTrue(cm.getLabels().contains(LABELS[1]));
@@ -75,8 +76,8 @@ public final class ConfusionMatrixTest e
private static void checkAccuracy(ConfusionMatrix cm) {
Collection<String> labelstrs = cm.getLabels();
assertEquals(3, labelstrs.size());
- assertEquals(40.0, cm.getAccuracy("Label1"), EPSILON);
- assertEquals(66.666666667, cm.getAccuracy("Label2"), EPSILON);
+ assertEquals(25.0, cm.getAccuracy("Label1"), EPSILON);
+ assertEquals(55.5555555, cm.getAccuracy("Label2"), EPSILON);
assertTrue(Double.isNaN(cm.getAccuracy("other")));
}
@@ -86,10 +87,12 @@ public final class ConfusionMatrixTest e
labelList.add(labels[1]);
ConfusionMatrix cm = new ConfusionMatrix(labelList, defaultLabel);
int[][] v = cm.getConfusionMatrix();
- v[0][0] = values[0][0];
- v[0][1] = values[0][1];
- v[1][0] = values[1][0];
- v[1][1] = values[1][1];
+ cm.putCount("Label1", "Label1", values[0][0]);
+ cm.putCount("Label1", "Label2", values[0][1]);
+ cm.putCount("Label2", "Label1", values[1][0]);
+ cm.putCount("Label2", "Label2", values[1][1]);
+ cm.putCount("Label1", DEFAULT_LABEL, OTHER[0]);
+ cm.putCount("Label2", DEFAULT_LABEL, OTHER[1]);
return cm;
}