You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ss...@apache.org on 2014/05/23 18:39:52 UTC
git commit: MAHOUT-1554 Provide more comprehensive classification statistics
Repository: mahout
Updated Branches:
refs/heads/master d850a091d -> 299fe6cc2
MAHOUT-1554 Provide more comprehensive classification statistics
Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/299fe6cc
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/299fe6cc
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/299fe6cc
Branch: refs/heads/master
Commit: 299fe6cc217f3d62057a55140a44ada0ce6ea145
Parents: d850a09
Author: ssc <ss...@apache.org>
Authored: Fri May 23 18:39:18 2014 +0200
Committer: ssc <ss...@apache.org>
Committed: Fri May 23 18:39:18 2014 +0200
----------------------------------------------------------------------
CHANGELOG | 2 +
.../mahout/classifier/ConfusionMatrix.java | 107 ++++++++++++++++++-
.../mahout/classifier/ResultAnalyzer.java | 13 ++-
.../mahout/classifier/ConfusionMatrixTest.java | 68 +++++++-----
4 files changed, 157 insertions(+), 33 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/mahout/blob/299fe6cc/CHANGELOG
----------------------------------------------------------------------
diff --git a/CHANGELOG b/CHANGELOG
index 5e9c2f2..2166426 100644
--- a/CHANGELOG
+++ b/CHANGELOG
@@ -2,6 +2,8 @@ Mahout Change Log
Release 1.0 - unreleased
+ MAHOUT-1554: Provide more comprehensive classification statistics (Karol Grzegorczyk via ssc)
+
MAHOUT-1548: Fix broken links in quickstart webpage (Andrew Palumbo via ssc)
MAHOUT-1542: Tutorial for playing with Mahout's Spark shell (ssc)
http://git-wip-us.apache.org/repos/asf/mahout/blob/299fe6cc/mrlegacy/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
----------------------------------------------------------------------
diff --git a/mrlegacy/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java b/mrlegacy/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
index d08023d..b8d96d7 100644
--- a/mrlegacy/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
+++ b/mrlegacy/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
@@ -22,6 +22,7 @@ import java.util.Collections;
import java.util.Map;
import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
import org.apache.mahout.math.DenseMatrix;
@@ -65,12 +66,16 @@ public class ConfusionMatrix {
public Collection<String> getLabels() {
return Collections.unmodifiableCollection(labelMap.keySet());
}
-
+
+ private int numLabels() {
+ return labelMap.size();
+ }
+
public double getAccuracy(String label) {
int labelId = labelMap.get(label);
int labelTotal = 0;
int correct = 0;
- for (int i = 0; i < labelMap.size(); i++) {
+ for (int i = 0; i < numLabels(); i++) {
labelTotal += confusionMatrix[labelId][i];
if (i == labelId) {
correct += confusionMatrix[labelId][i];
@@ -83,8 +88,8 @@ public class ConfusionMatrix {
public double getAccuracy() {
int total = 0;
int correct = 0;
- for (int i = 0; i < labelMap.size(); i++) {
- for (int j = 0; j < labelMap.size(); j++) {
+ for (int i = 0; i < numLabels(); i++) {
+ for (int j = 0; j < numLabels(); j++) {
total += confusionMatrix[i][j];
if (i == j) {
correct += confusionMatrix[i][j];
@@ -93,7 +98,99 @@ public class ConfusionMatrix {
}
return 100.0 * correct / total;
}
-
+
+ /** Sum of true positives and false negatives */
+ private int getActualNumberOfTestExamplesForClass(String label) {
+ int labelId = labelMap.get(label);
+ int sum = 0;
+ for (int i = 0; i < numLabels(); i++) {
+ sum += confusionMatrix[labelId][i];
+ }
+ return sum;
+ }
+
+ public double getPrecision(String label) {
+ int labelId = labelMap.get(label);
+ int truePositives = confusionMatrix[labelId][labelId];
+ int falsePositives = 0;
+ for (int i = 0; i < numLabels(); i++) {
+ if (i == labelId) {
+ continue;
+ }
+ falsePositives += confusionMatrix[i][labelId];
+ }
+
+ if (truePositives + falsePositives == 0) {
+ return 0;
+ }
+
+ return ((double) truePositives) / (truePositives + falsePositives);
+ }
+
+ public double getWeightedPrecision() {
+ double[] precisions = new double[numLabels()];
+ double[] weights = new double[numLabels()];
+
+ int index = 0;
+ for (String label : labelMap.keySet()) {
+ precisions[index] = getPrecision(label);
+ weights[index] = getActualNumberOfTestExamplesForClass(label);
+ index++;
+ }
+ return new Mean().evaluate(precisions, weights);
+ }
+
+ public double getRecall(String label) {
+ int labelId = labelMap.get(label);
+ int truePositives = confusionMatrix[labelId][labelId];
+ int falseNegatives = 0;
+ for (int i = 0; i < numLabels(); i++) {
+ if (i == labelId) {
+ continue;
+ }
+ falseNegatives += confusionMatrix[labelId][i];
+ }
+ if (truePositives + falseNegatives == 0) {
+ return 0;
+ }
+ return ((double) truePositives) / (truePositives + falseNegatives);
+ }
+
+ public double getWeightedRecall() {
+ double[] recalls = new double[numLabels()];
+ double[] weights = new double[numLabels()];
+
+ int index = 0;
+ for (String label : labelMap.keySet()) {
+ recalls[index] = getRecall(label);
+ weights[index] = getActualNumberOfTestExamplesForClass(label);
+ index++;
+ }
+ return new Mean().evaluate(recalls, weights);
+ }
+
+ public double getF1score(String label) {
+ double precision = getPrecision(label);
+ double recall = getRecall(label);
+ if (precision + recall == 0) {
+ return 0;
+ }
+ return 2 * precision * recall / (precision + recall);
+ }
+
+ public double getWeightedF1score() {
+ double[] f1Scores = new double[numLabels()];
+ double[] weights = new double[numLabels()];
+
+ int index = 0;
+ for (String label : labelMap.keySet()) {
+ f1Scores[index] = getF1score(label);
+ weights[index] = getActualNumberOfTestExamplesForClass(label);
+ index++;
+ }
+ return new Mean().evaluate(f1Scores, weights);
+ }
+
// User accuracy
public double getReliability() {
int count = 0;
http://git-wip-us.apache.org/repos/asf/mahout/blob/299fe6cc/mrlegacy/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
----------------------------------------------------------------------
diff --git a/mrlegacy/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java b/mrlegacy/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
index 64e3e5b..1711f19 100644
--- a/mrlegacy/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
+++ b/mrlegacy/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
@@ -25,14 +25,13 @@ import org.apache.commons.lang3.StringUtils;
import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
import org.apache.mahout.math.stats.OnlineSummarizer;
-/**
- * ResultAnalyzer captures the classification statistics and displays in a tabular manner
- */
+/** ResultAnalyzer captures the classification statistics and displays in a tabular manner */
public class ResultAnalyzer {
private final ConfusionMatrix confusionMatrix;
private final OnlineSummarizer summarizer;
private boolean hasLL;
+
/*
* === Summary ===
*
@@ -111,7 +110,13 @@ public class ResultAnalyzer {
returnString.append(StringUtils.rightPad("Reliability", 40)).append(
StringUtils.leftPad(decimalFormatter.format(normStats.getAverage() * 100.00000001), 10)).append("%\n");
returnString.append(StringUtils.rightPad("Reliability (standard deviation)", 40)).append(
- StringUtils.leftPad(decimalFormatter.format(normStats.getStandardDeviation()), 10)).append('\n');
+ StringUtils.leftPad(decimalFormatter.format(normStats.getStandardDeviation()), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Weighted precision", 40)).append(
+ StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getWeightedPrecision()), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Weighted recall", 40)).append(
+ StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getWeightedRecall()), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Weighted F1 score", 40)).append(
+ StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getWeightedF1score()), 10)).append('\n');
if (hasLL) {
returnString.append(StringUtils.rightPad("Log-likelihood", 30)).append("mean : ").append(
http://git-wip-us.apache.org/repos/asf/mahout/blob/299fe6cc/mrlegacy/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
----------------------------------------------------------------------
diff --git a/mrlegacy/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java b/mrlegacy/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
index 8fa7c97..ebbed92 100644
--- a/mrlegacy/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
+++ b/mrlegacy/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
@@ -35,25 +35,46 @@ public final class ConfusionMatrixTest extends MahoutTestCase {
@Test
public void testBuild() {
- ConfusionMatrix cm = fillCM(VALUES, LABELS, DEFAULT_LABEL);
- checkValues(cm);
- checkAccuracy(cm);
+ ConfusionMatrix confusionMatrix = fillConfusionMatrix(VALUES, LABELS, DEFAULT_LABEL);
+ checkValues(confusionMatrix);
+ checkAccuracy(confusionMatrix);
}
@Test
public void testGetMatrix() {
- ConfusionMatrix cm = fillCM(VALUES, LABELS, DEFAULT_LABEL);
- Matrix m = cm.getMatrix();
- Map<String, Integer> rowLabels = m.getRowLabelBindings();
- assertEquals(cm.getLabels().size(), m.numCols());
- assertTrue(rowLabels.keySet().contains(LABELS[0]));
- assertTrue(rowLabels.keySet().contains(LABELS[1]));
- assertTrue(rowLabels.keySet().contains(DEFAULT_LABEL));
- assertEquals(2, cm.getCorrect(LABELS[0]));
- assertEquals(20, cm.getCorrect(LABELS[1]));
- assertEquals(0, cm.getCorrect(DEFAULT_LABEL));
+ ConfusionMatrix confusionMatrix = fillConfusionMatrix(VALUES, LABELS, DEFAULT_LABEL);
+ Matrix m = confusionMatrix.getMatrix();
+ Map<String, Integer> rowLabels = m.getRowLabelBindings();
+
+ assertEquals(confusionMatrix.getLabels().size(), m.numCols());
+ assertTrue(rowLabels.keySet().contains(LABELS[0]));
+ assertTrue(rowLabels.keySet().contains(LABELS[1]));
+ assertTrue(rowLabels.keySet().contains(DEFAULT_LABEL));
+ assertEquals(2, confusionMatrix.getCorrect(LABELS[0]));
+ assertEquals(20, confusionMatrix.getCorrect(LABELS[1]));
+ assertEquals(0, confusionMatrix.getCorrect(DEFAULT_LABEL));
}
+ /**
+ * Example taken from
+ * http://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html
+ */
+ @Test
+ public void testPrecisionRecallAndF1ScoreAsScikitLearn() {
+ Collection<String> labelList = Arrays.asList("0", "1", "2");
+
+ ConfusionMatrix confusionMatrix = new ConfusionMatrix(labelList, "DEFAULT");
+ confusionMatrix.putCount("0", "0", 2);
+ confusionMatrix.putCount("1", "0", 1);
+ confusionMatrix.putCount("1", "2", 1);
+ confusionMatrix.putCount("2", "1", 2);
+
+ double delta = 0.001;
+ assertEquals(0.222, confusionMatrix.getWeightedPrecision(), delta);
+ assertEquals(0.333, confusionMatrix.getWeightedRecall(), delta);
+ assertEquals(0.266, confusionMatrix.getWeightedF1score(), delta);
+ }
+
private static void checkValues(ConfusionMatrix cm) {
int[][] counts = cm.getConfusionMatrix();
cm.toString();
@@ -70,7 +91,6 @@ public final class ConfusionMatrixTest extends MahoutTestCase {
assertTrue(cm.getLabels().contains(LABELS[0]));
assertTrue(cm.getLabels().contains(LABELS[1]));
assertTrue(cm.getLabels().contains(DEFAULT_LABEL));
-
}
private static void checkAccuracy(ConfusionMatrix cm) {
@@ -81,19 +101,19 @@ public final class ConfusionMatrixTest extends MahoutTestCase {
assertTrue(Double.isNaN(cm.getAccuracy("other")));
}
- private static ConfusionMatrix fillCM(int[][] values, String[] labels, String defaultLabel) {
+ private static ConfusionMatrix fillConfusionMatrix(int[][] values, String[] labels, String defaultLabel) {
Collection<String> labelList = Lists.newArrayList();
labelList.add(labels[0]);
labelList.add(labels[1]);
- ConfusionMatrix cm = new ConfusionMatrix(labelList, defaultLabel);
- //int[][] v = cm.getConfusionMatrix();
- cm.putCount("Label1", "Label1", values[0][0]);
- cm.putCount("Label1", "Label2", values[0][1]);
- cm.putCount("Label2", "Label1", values[1][0]);
- cm.putCount("Label2", "Label2", values[1][1]);
- cm.putCount("Label1", DEFAULT_LABEL, OTHER[0]);
- cm.putCount("Label2", DEFAULT_LABEL, OTHER[1]);
- return cm;
+ ConfusionMatrix confusionMatrix = new ConfusionMatrix(labelList, defaultLabel);
+
+ confusionMatrix.putCount("Label1", "Label1", values[0][0]);
+ confusionMatrix.putCount("Label1", "Label2", values[0][1]);
+ confusionMatrix.putCount("Label2", "Label1", values[1][0]);
+ confusionMatrix.putCount("Label2", "Label2", values[1][1]);
+ confusionMatrix.putCount("Label1", DEFAULT_LABEL, OTHER[0]);
+ confusionMatrix.putCount("Label2", DEFAULT_LABEL, OTHER[1]);
+ return confusionMatrix;
}
}