You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2011/01/17 11:24:54 UTC

svn commit: r1059847 - /mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java

Author: srowen
Date: Mon Jan 17 10:24:53 2011
New Revision: 1059847

URL: http://svn.apache.org/viewvc?rev=1059847&view=rev
Log:
MAHOUT-569 Fix an NPE caused by the separate labels collection and labelMap getting out of sync; labels are now derived from labelMap alone

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java?rev=1059847&r1=1059846&r2=1059847&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java Mon Jan 17 10:24:53 2011
@@ -18,7 +18,9 @@
 package org.apache.mahout.classifier;
 
 import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
+import java.util.LinkedHashMap;
 import java.util.Map;
 
 import org.apache.commons.lang.StringUtils;
@@ -32,17 +34,12 @@ import com.google.common.base.Preconditi
  * See http://en.wikipedia.org/wiki/Confusion_matrix for background
  */
 public class ConfusionMatrix implements Summarizable {
-  
-  private final Collection<String> labels;
-  
-  private final Map<String,Integer> labelMap = new HashMap<String,Integer>();
-  
+
+  private final Map<String,Integer> labelMap = new LinkedHashMap<String,Integer>();
   private final int[][] confusionMatrix;
-  
   private String defaultLabel = "unknown";
   
   public ConfusionMatrix(Collection<String> labels, String defaultLabel) {
-    this.labels = labels;
     confusionMatrix = new int[labels.size() + 1][labels.size() + 1];
     this.defaultLabel = defaultLabel;
     for (String label : labels) {
@@ -56,14 +53,14 @@ public class ConfusionMatrix implements 
   }
   
   public Collection<String> getLabels() {
-    return labels;
+    return Collections.unmodifiableCollection(labelMap.keySet());
   }
   
   public double getAccuracy(String label) {
     int labelId = labelMap.get(label);
     int labelTotal = 0;
     int correct = 0;
-    for (int i = 0; i < labels.size(); i++) {
+    for (int i = 0; i < labelMap.size(); i++) {
       labelTotal += confusionMatrix[labelId][i];
       if (i == labelId) {
         correct = confusionMatrix[labelId][i];
@@ -80,7 +77,7 @@ public class ConfusionMatrix implements 
   public double getTotal(String label) {
     int labelId = labelMap.get(label);
     int labelTotal = 0;
-    for (int i = 0; i < labels.size(); i++) {
+    for (int i = 0; i < labelMap.size(); i++) {
       labelTotal += confusionMatrix[labelId][i];
     }
     return labelTotal;
@@ -95,8 +92,8 @@ public class ConfusionMatrix implements 
   }
   
   public int getCount(String correctLabel, String classifiedLabel) {
-    Preconditions.checkArgument(!labels.contains(correctLabel)
-        || labels.contains(classifiedLabel)
+    Preconditions.checkArgument(!labelMap.containsKey(correctLabel)
+        || labelMap.containsKey(classifiedLabel)
         || defaultLabel.equals(classifiedLabel),
         "Label not found " + correctLabel + ' ' + classifiedLabel);
     int correctId = labelMap.get(correctLabel);
@@ -105,8 +102,8 @@ public class ConfusionMatrix implements 
   }
   
   public void putCount(String correctLabel, String classifiedLabel, int count) {
-    Preconditions.checkArgument(!labels.contains(correctLabel)
-        || labels.contains(classifiedLabel)
+    Preconditions.checkArgument(!labelMap.containsKey(correctLabel)
+        || labelMap.containsKey(classifiedLabel)
         || defaultLabel.equals(classifiedLabel),
         "Label not found " + correctLabel + ' ' + classifiedLabel);
     int correctId = labelMap.get(correctLabel);
@@ -123,9 +120,9 @@ public class ConfusionMatrix implements 
   }
   
   public ConfusionMatrix merge(ConfusionMatrix b) {
-    Preconditions.checkArgument(labels.size() == b.getLabels().size(), "The label sizes do not match");
-    for (String correctLabel : this.labels) {
-      for (String classifiedLabel : this.labels) {
+    Preconditions.checkArgument(labelMap.size() == b.getLabels().size(), "The label sizes do not match");
+    for (String correctLabel : this.labelMap.keySet()) {
+      for (String classifiedLabel : this.labelMap.keySet()) {
         incrementCount(correctLabel, classifiedLabel, b.getCount(correctLabel, classifiedLabel));
       }
     }
@@ -139,16 +136,15 @@ public class ConfusionMatrix implements 
     returnString.append("Confusion Matrix\n");
     returnString.append("-------------------------------------------------------").append('\n');
     
-    for (String correctLabel : this.labels) {
-      returnString.append(StringUtils.rightPad(getSmallLabel(labelMap.get(correctLabel)), 5))
-          .append('\t');
+    for (String correctLabel : this.labelMap.keySet()) {
+      returnString.append(StringUtils.rightPad(getSmallLabel(labelMap.get(correctLabel)), 5)).append('\t');
     }
     
     returnString.append("<--Classified as").append('\n');
     
-    for (String correctLabel : this.labels) {
+    for (String correctLabel : this.labelMap.keySet()) {
       int labelTotal = 0;
-      for (String classifiedLabel : this.labels) {
+      for (String classifiedLabel : this.labelMap.keySet()) {
         returnString.append(
           StringUtils.rightPad(Integer.toString(getCount(correctLabel, classifiedLabel)), 5)).append('\t');
         labelTotal += getCount(correctLabel, classifiedLabel);