You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2011/01/17 11:24:54 UTC
svn commit: r1059847 -
/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
Author: srowen
Date: Mon Jan 17 10:24:53 2011
New Revision: 1059847
URL: http://svn.apache.org/viewvc?rev=1059847&view=rev
Log:
MAHOUT-569 Fix an NPE that occurs because labels/labelMap somehow go out of sync
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java?rev=1059847&r1=1059846&r2=1059847&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java Mon Jan 17 10:24:53 2011
@@ -18,7 +18,9 @@
package org.apache.mahout.classifier;
import java.util.Collection;
+import java.util.Collections;
import java.util.HashMap;
+import java.util.LinkedHashMap;
import java.util.Map;
import org.apache.commons.lang.StringUtils;
@@ -32,17 +34,12 @@ import com.google.common.base.Preconditi
* See http://en.wikipedia.org/wiki/Confusion_matrix for background
*/
public class ConfusionMatrix implements Summarizable {
-
- private final Collection<String> labels;
-
- private final Map<String,Integer> labelMap = new HashMap<String,Integer>();
-
+
+ private final Map<String,Integer> labelMap = new LinkedHashMap<String,Integer>();
private final int[][] confusionMatrix;
-
private String defaultLabel = "unknown";
public ConfusionMatrix(Collection<String> labels, String defaultLabel) {
- this.labels = labels;
confusionMatrix = new int[labels.size() + 1][labels.size() + 1];
this.defaultLabel = defaultLabel;
for (String label : labels) {
@@ -56,14 +53,14 @@ public class ConfusionMatrix implements
}
public Collection<String> getLabels() {
- return labels;
+ return Collections.unmodifiableCollection(labelMap.keySet());
}
public double getAccuracy(String label) {
int labelId = labelMap.get(label);
int labelTotal = 0;
int correct = 0;
- for (int i = 0; i < labels.size(); i++) {
+ for (int i = 0; i < labelMap.size(); i++) {
labelTotal += confusionMatrix[labelId][i];
if (i == labelId) {
correct = confusionMatrix[labelId][i];
@@ -80,7 +77,7 @@ public class ConfusionMatrix implements
public double getTotal(String label) {
int labelId = labelMap.get(label);
int labelTotal = 0;
- for (int i = 0; i < labels.size(); i++) {
+ for (int i = 0; i < labelMap.size(); i++) {
labelTotal += confusionMatrix[labelId][i];
}
return labelTotal;
@@ -95,8 +92,8 @@ public class ConfusionMatrix implements
}
public int getCount(String correctLabel, String classifiedLabel) {
- Preconditions.checkArgument(!labels.contains(correctLabel)
- || labels.contains(classifiedLabel)
+ Preconditions.checkArgument(!labelMap.containsKey(correctLabel)
+ || labelMap.containsKey(classifiedLabel)
|| defaultLabel.equals(classifiedLabel),
"Label not found " + correctLabel + ' ' + classifiedLabel);
int correctId = labelMap.get(correctLabel);
@@ -105,8 +102,8 @@ public class ConfusionMatrix implements
}
public void putCount(String correctLabel, String classifiedLabel, int count) {
- Preconditions.checkArgument(!labels.contains(correctLabel)
- || labels.contains(classifiedLabel)
+ Preconditions.checkArgument(!labelMap.containsKey(correctLabel)
+ || labelMap.containsKey(classifiedLabel)
|| defaultLabel.equals(classifiedLabel),
"Label not found " + correctLabel + ' ' + classifiedLabel);
int correctId = labelMap.get(correctLabel);
@@ -123,9 +120,9 @@ public class ConfusionMatrix implements
}
public ConfusionMatrix merge(ConfusionMatrix b) {
- Preconditions.checkArgument(labels.size() == b.getLabels().size(), "The label sizes do not match");
- for (String correctLabel : this.labels) {
- for (String classifiedLabel : this.labels) {
+ Preconditions.checkArgument(labelMap.size() == b.getLabels().size(), "The label sizes do not match");
+ for (String correctLabel : this.labelMap.keySet()) {
+ for (String classifiedLabel : this.labelMap.keySet()) {
incrementCount(correctLabel, classifiedLabel, b.getCount(correctLabel, classifiedLabel));
}
}
@@ -139,16 +136,15 @@ public class ConfusionMatrix implements
returnString.append("Confusion Matrix\n");
returnString.append("-------------------------------------------------------").append('\n');
- for (String correctLabel : this.labels) {
- returnString.append(StringUtils.rightPad(getSmallLabel(labelMap.get(correctLabel)), 5))
- .append('\t');
+ for (String correctLabel : this.labelMap.keySet()) {
+ returnString.append(StringUtils.rightPad(getSmallLabel(labelMap.get(correctLabel)), 5)).append('\t');
}
returnString.append("<--Classified as").append('\n');
- for (String correctLabel : this.labels) {
+ for (String correctLabel : this.labelMap.keySet()) {
int labelTotal = 0;
- for (String classifiedLabel : this.labels) {
+ for (String classifiedLabel : this.labelMap.keySet()) {
returnString.append(
StringUtils.rightPad(Integer.toString(getCount(correctLabel, classifiedLabel)), 5)).append('\t');
labelTotal += getCount(correctLabel, classifiedLabel);