You are viewing a plain text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@mahout.apache.org by ss...@apache.org on 2014/05/23 18:39:52 UTC

git commit: MAHOUT-1554 Provide more comprehensive classification statistics

Repository: mahout
Updated Branches:
  refs/heads/master d850a091d -> 299fe6cc2


MAHOUT-1554 Provide more comprehensive classification statistics


Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/299fe6cc
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/299fe6cc
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/299fe6cc

Branch: refs/heads/master
Commit: 299fe6cc217f3d62057a55140a44ada0ce6ea145
Parents: d850a09
Author: ssc <ss...@apache.org>
Authored: Fri May 23 18:39:18 2014 +0200
Committer: ssc <ss...@apache.org>
Committed: Fri May 23 18:39:18 2014 +0200

----------------------------------------------------------------------
 CHANGELOG                                       |   2 +
 .../mahout/classifier/ConfusionMatrix.java      | 107 ++++++++++++++++++-
 .../mahout/classifier/ResultAnalyzer.java       |  13 ++-
 .../mahout/classifier/ConfusionMatrixTest.java  |  68 +++++++-----
 4 files changed, 157 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/mahout/blob/299fe6cc/CHANGELOG
----------------------------------------------------------------------
diff --git a/CHANGELOG b/CHANGELOG
index 5e9c2f2..2166426 100644
--- a/CHANGELOG
+++ b/CHANGELOG
@@ -2,6 +2,8 @@ Mahout Change Log
 
 Release 1.0 - unreleased
 
+  MAHOUT-1554: Provide more comprehensive classification statistics (Karol Grzegorczyk via ssc)
+
   MAHOUT-1548: Fix broken links in quickstart webpage (Andrew Palumbo via ssc)
 
   MAHOUT-1542: Tutorial for playing with Mahout's Spark shell (ssc)

http://git-wip-us.apache.org/repos/asf/mahout/blob/299fe6cc/mrlegacy/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
----------------------------------------------------------------------
diff --git a/mrlegacy/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java b/mrlegacy/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
index d08023d..b8d96d7 100644
--- a/mrlegacy/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
+++ b/mrlegacy/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
@@ -22,6 +22,7 @@ import java.util.Collections;
 import java.util.Map;
 
 import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.math3.stat.descriptive.moment.Mean;
 import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
 import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
 import org.apache.mahout.math.DenseMatrix;
@@ -65,12 +66,16 @@ public class ConfusionMatrix {
   public Collection<String> getLabels() {
     return Collections.unmodifiableCollection(labelMap.keySet());
   }
-  
+
+  private int numLabels() {
+    return labelMap.keySet().size(); // number of labels known to this matrix
+  }
+
   public double getAccuracy(String label) {
     int labelId = labelMap.get(label);
     int labelTotal = 0;
     int correct = 0;
-    for (int i = 0; i < labelMap.size(); i++) {
+    for (int i = 0; i < numLabels(); i++) {
       labelTotal += confusionMatrix[labelId][i];
       if (i == labelId) {
         correct += confusionMatrix[labelId][i];
@@ -83,8 +88,8 @@ public class ConfusionMatrix {
   public double getAccuracy() {
     int total = 0;
     int correct = 0;
-    for (int i = 0; i < labelMap.size(); i++) {
-      for (int j = 0; j < labelMap.size(); j++) {
+    for (int i = 0; i < numLabels(); i++) {
+      for (int j = 0; j < numLabels(); j++) {
         total += confusionMatrix[i][j];
         if (i == j) {
           correct += confusionMatrix[i][j];
@@ -93,7 +98,99 @@ public class ConfusionMatrix {
     }
     return 100.0 * correct / total;
   }
-  
+
+  /** Number of examples whose actual label is {@code label}: its true positives plus false negatives. */
+  private int getActualNumberOfTestExamplesForClass(String label) {
+    int[] actualRow = confusionMatrix[labelMap.get(label)];
+    int total = 0;
+    for (int column = 0; column < numLabels(); column++) {
+      total += actualRow[column];
+    }
+    return total;
+  }
+
+  /**
+   * Precision of {@code label}: the fraction of examples predicted as {@code label} whose actual
+   * label is also {@code label}, i.e. TP / (TP + FP).
+   *
+   * @return TP / (TP + FP), or 0 when no example was predicted as {@code label}
+   */
+  public double getPrecision(String label) {
+    int column = labelMap.get(label);
+    int predictedCount = 0;
+    for (int actual = 0; actual < numLabels(); actual++) {
+      predictedCount += confusionMatrix[actual][column];
+    }
+    if (predictedCount == 0) {
+      return 0;
+    }
+    return (double) confusionMatrix[column][column] / predictedCount;
+  }
+
+  /** Precision averaged over labels, weighted by each label's actual example count; NaN if there are none. */
+  public double getWeightedPrecision() {
+    double[] precisions = new double[numLabels()];
+    double[] weights = new double[numLabels()];
+    int index = 0;
+    for (String label : labelMap.keySet()) {
+      precisions[index] = getPrecision(label);
+      weights[index++] = getActualNumberOfTestExamplesForClass(label);
+    }
+    // positive mean count <=> some label has examples; Mean.evaluate(values, weights) throws otherwise
+    return new Mean().evaluate(weights) > 0 ? new Mean().evaluate(precisions, weights) : Double.NaN;
+  }
+
+  /**
+   * Recall of {@code label}: the fraction of examples whose actual label is {@code label} that were
+   * also predicted as {@code label}, i.e. TP / (TP + FN).
+   *
+   * @return TP / (TP + FN), or 0 when {@code label} has no examples
+   */
+  public double getRecall(String label) {
+    // the row total of this label is exactly TP + FN
+    int actualCount = getActualNumberOfTestExamplesForClass(label);
+    if (actualCount == 0) {
+      return 0;
+    }
+    int row = labelMap.get(label);
+    return (double) confusionMatrix[row][row] / actualCount;
+  }
+
+  /** Recall averaged over labels, weighted by each label's actual example count; NaN if there are none. */
+  public double getWeightedRecall() {
+    double[] recalls = new double[numLabels()];
+    double[] weights = new double[numLabels()];
+    int index = 0;
+    for (String label : labelMap.keySet()) {
+      recalls[index] = getRecall(label);
+      weights[index++] = getActualNumberOfTestExamplesForClass(label);
+    }
+    // positive mean count <=> some label has examples; Mean.evaluate(values, weights) throws otherwise
+    return new Mean().evaluate(weights) > 0 ? new Mean().evaluate(recalls, weights) : Double.NaN;
+  }
+
+  /** F1 score of {@code label}: harmonic mean of its precision and recall, or 0 if both are 0. */
+  public double getF1score(String label) {
+    double precision = getPrecision(label);
+    double recall = getRecall(label);
+    double sum = precision + recall;
+    // both zero (no true positives): define F1 as 0 rather than evaluate 0 / 0
+    return sum == 0 ? 0 : 2 * precision * recall / sum;
+  }
+
+  /** F1 score averaged over labels, weighted by each label's actual example count; NaN if there are none. */
+  public double getWeightedF1score() {
+    double[] f1Scores = new double[numLabels()];
+    double[] weights = new double[numLabels()];
+    int index = 0;
+    for (String label : labelMap.keySet()) {
+      f1Scores[index] = getF1score(label);
+      weights[index++] = getActualNumberOfTestExamplesForClass(label);
+    }
+    // positive mean count <=> some label has examples; Mean.evaluate(values, weights) throws otherwise
+    return new Mean().evaluate(weights) > 0 ? new Mean().evaluate(f1Scores, weights) : Double.NaN;
+  }
+
   // User accuracy 
   public double getReliability() {
     int count = 0;

http://git-wip-us.apache.org/repos/asf/mahout/blob/299fe6cc/mrlegacy/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
----------------------------------------------------------------------
diff --git a/mrlegacy/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java b/mrlegacy/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
index 64e3e5b..1711f19 100644
--- a/mrlegacy/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
+++ b/mrlegacy/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
@@ -25,14 +25,13 @@ import org.apache.commons.lang3.StringUtils;
 import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
 import org.apache.mahout.math.stats.OnlineSummarizer;
 
-/**
- * ResultAnalyzer captures the classification statistics and displays in a tabular manner
- */
+/** ResultAnalyzer captures the classification statistics and displays them in a tabular manner. */
 public class ResultAnalyzer {
   
   private final ConfusionMatrix confusionMatrix;
   private final OnlineSummarizer summarizer;
   private boolean hasLL;
+
   /*
    * === Summary ===
    * 
@@ -111,7 +110,13 @@ public class ResultAnalyzer {
     returnString.append(StringUtils.rightPad("Reliability", 40)).append(
       StringUtils.leftPad(decimalFormatter.format(normStats.getAverage() * 100.00000001), 10)).append("%\n");
     returnString.append(StringUtils.rightPad("Reliability (standard deviation)", 40)).append(
-      StringUtils.leftPad(decimalFormatter.format(normStats.getStandardDeviation()), 10)).append('\n'); 
+      StringUtils.leftPad(decimalFormatter.format(normStats.getStandardDeviation()), 10)).append('\n');
+    returnString.append(StringUtils.rightPad("Weighted precision", 40)).append(
+      StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getWeightedPrecision()), 10)).append('\n');
+    returnString.append(StringUtils.rightPad("Weighted recall", 40)).append(
+      StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getWeightedRecall()), 10)).append('\n');
+    returnString.append(StringUtils.rightPad("Weighted F1 score", 40)).append(
+      StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getWeightedF1score()), 10)).append('\n');
     
     if (hasLL) {
       returnString.append(StringUtils.rightPad("Log-likelihood", 30)).append("mean      : ").append(

http://git-wip-us.apache.org/repos/asf/mahout/blob/299fe6cc/mrlegacy/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
----------------------------------------------------------------------
diff --git a/mrlegacy/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java b/mrlegacy/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
index 8fa7c97..ebbed92 100644
--- a/mrlegacy/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
+++ b/mrlegacy/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
@@ -35,25 +35,46 @@ public final class ConfusionMatrixTest extends MahoutTestCase {
   
   @Test
   public void testBuild() {
-    ConfusionMatrix cm = fillCM(VALUES, LABELS, DEFAULT_LABEL);
-    checkValues(cm);
-    checkAccuracy(cm);
+    ConfusionMatrix filled = fillConfusionMatrix(VALUES, LABELS, DEFAULT_LABEL);
+    checkValues(filled); // counts and label set
+    checkAccuracy(filled); // accuracy figures
   }
 
   @Test
   public void testGetMatrix() {
-	    ConfusionMatrix cm = fillCM(VALUES, LABELS, DEFAULT_LABEL);
-	    Matrix m = cm.getMatrix();
-	    Map<String, Integer> rowLabels = m.getRowLabelBindings();
-	    assertEquals(cm.getLabels().size(), m.numCols());
-	    assertTrue(rowLabels.keySet().contains(LABELS[0]));
-	    assertTrue(rowLabels.keySet().contains(LABELS[1]));
-	    assertTrue(rowLabels.keySet().contains(DEFAULT_LABEL));
-	    assertEquals(2, cm.getCorrect(LABELS[0]));
-	    assertEquals(20, cm.getCorrect(LABELS[1]));
-	    assertEquals(0, cm.getCorrect(DEFAULT_LABEL));
+    ConfusionMatrix confusionMatrix = fillConfusionMatrix(VALUES, LABELS, DEFAULT_LABEL);
+    Matrix matrix = confusionMatrix.getMatrix();
+    Map<String, Integer> rowLabels = matrix.getRowLabelBindings();
+    // one column per label, and every label (default included) bound to a row
+    assertEquals(confusionMatrix.getLabels().size(), matrix.numCols());
+    assertTrue(rowLabels.containsKey(LABELS[0]));
+    assertTrue(rowLabels.containsKey(LABELS[1]));
+    assertTrue(rowLabels.containsKey(DEFAULT_LABEL));
+    assertEquals(2, confusionMatrix.getCorrect(LABELS[0]));
+    assertEquals(20, confusionMatrix.getCorrect(LABELS[1]));
+    assertEquals(0, confusionMatrix.getCorrect(DEFAULT_LABEL));
   }
 
+  /**
+   * Weighted precision, recall and F1 must match the {@code average='weighted'} example in
+   * http://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html
+   */
+  @Test
+  public void testPrecisionRecallAndF1ScoreAsScikitLearn() {
+    Collection<String> labelList = Arrays.asList("0", "1", "2");
+
+    ConfusionMatrix confusionMatrix = new ConfusionMatrix(labelList, "DEFAULT");
+    confusionMatrix.putCount("0", "0", 2);
+    confusionMatrix.putCount("1", "0", 1);
+    confusionMatrix.putCount("1", "2", 1);
+    confusionMatrix.putCount("2", "1", 2);
+    // every class has 2 examples; only class "0" scores: precision 2/3, recall 1, F1 4/5
+    double delta = 1.0e-9;
+    assertEquals(2.0 / 9, confusionMatrix.getWeightedPrecision(), delta);
+    assertEquals(1.0 / 3, confusionMatrix.getWeightedRecall(), delta);
+    assertEquals(4.0 / 15, confusionMatrix.getWeightedF1score(), delta);
+  }
+
   private static void checkValues(ConfusionMatrix cm) {
     int[][] counts = cm.getConfusionMatrix();
     cm.toString();
@@ -70,7 +91,6 @@ public final class ConfusionMatrixTest extends MahoutTestCase {
     assertTrue(cm.getLabels().contains(LABELS[0]));
     assertTrue(cm.getLabels().contains(LABELS[1]));
     assertTrue(cm.getLabels().contains(DEFAULT_LABEL));
-
   }
 
   private static void checkAccuracy(ConfusionMatrix cm) {
@@ -81,19 +101,19 @@ public final class ConfusionMatrixTest extends MahoutTestCase {
     assertTrue(Double.isNaN(cm.getAccuracy("other")));
   }
   
-  private static ConfusionMatrix fillCM(int[][] values, String[] labels, String defaultLabel) {
+  private static ConfusionMatrix fillConfusionMatrix(int[][] values, String[] labels, String defaultLabel) {
     Collection<String> labelList = Lists.newArrayList();
     labelList.add(labels[0]);
     labelList.add(labels[1]);
-    ConfusionMatrix cm = new ConfusionMatrix(labelList, defaultLabel);
-    //int[][] v = cm.getConfusionMatrix();
-    cm.putCount("Label1", "Label1", values[0][0]);
-    cm.putCount("Label1", "Label2", values[0][1]);
-    cm.putCount("Label2", "Label1", values[1][0]);
-    cm.putCount("Label2", "Label2", values[1][1]);
-    cm.putCount("Label1", DEFAULT_LABEL, OTHER[0]);
-    cm.putCount("Label2", DEFAULT_LABEL, OTHER[1]);
-    return cm;
+    ConfusionMatrix confusionMatrix = new ConfusionMatrix(labelList, defaultLabel);
+    // fill from the arguments (not hard-coded names) so the helper honours the labels it is given
+    confusionMatrix.putCount(labels[0], labels[0], values[0][0]);
+    confusionMatrix.putCount(labels[0], labels[1], values[0][1]);
+    confusionMatrix.putCount(labels[1], labels[0], values[1][0]);
+    confusionMatrix.putCount(labels[1], labels[1], values[1][1]);
+    confusionMatrix.putCount(labels[0], defaultLabel, OTHER[0]);
+    confusionMatrix.putCount(labels[1], defaultLabel, OTHER[1]);
+    return confusionMatrix;
   }
   
 }