You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by to...@apache.org on 2016/03/31 14:50:34 UTC
lucene-solr:master: LUCENE-7156 - fixed precision and accuracy calculations
Repository: lucene-solr
Updated Branches:
refs/heads/master e1b45568b -> d08f327a7
LUCENE-7156 - fixed precision and accuracy calculations
Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/d08f327a
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/d08f327a
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/d08f327a
Branch: refs/heads/master
Commit: d08f327a7f7d8b25272fb3e32fd6cc44cec1c03a
Parents: e1b4556
Author: Tommaso Teofili <te...@adobe.com>
Authored: Thu Mar 31 14:45:11 2016 +0200
Committer: Tommaso Teofili <te...@adobe.com>
Committed: Thu Mar 31 14:45:31 2016 +0200
----------------------------------------------------------------------
.../utils/ConfusionMatrixGenerator.java | 41 ++++++-----
.../utils/ConfusionMatrixGeneratorTest.java | 75 ++++++++++++--------
2 files changed, 68 insertions(+), 48 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/d08f327a/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java
----------------------------------------------------------------------
diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java
index 17f5b21..c9ecc4b 100644
--- a/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java
+++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java
@@ -106,11 +106,11 @@ public class ConfusionMatrixGenerator {
if (aLong != null) {
stringLongMap.put(classified, aLong + 1);
} else {
- stringLongMap.put(classified, 1l);
+ stringLongMap.put(classified, 1L);
}
} else {
stringLongMap = new HashMap<>();
- stringLongMap.put(classified, 1l);
+ stringLongMap.put(classified, 1L);
counts.put(correctAnswer, stringLongMap);
}
@@ -225,23 +225,29 @@ public class ConfusionMatrixGenerator {
*/
public double getAccuracy() {
if (this.accuracy == -1) {
- double cc = 0d;
- double wc = 0d;
- for (Map.Entry<String, Map<String, Long>> entry : linearizedMatrix.entrySet()) {
- String correctAnswer = entry.getKey();
- for (Map.Entry<String, Long> classifiedAnswers : entry.getValue().entrySet()) {
- Long value = classifiedAnswers.getValue();
- if (value != null) {
- if (correctAnswer.equals(classifiedAnswers.getKey())) {
- cc += value;
- } else {
- wc += value;
- }
+ double tp = 0d;
+ double tn = 0d;
+ double fp = 0d;
+ double fn = 0d;
+ for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
+ String klass = classification.getKey();
+ for (Map.Entry<String, Long> entry : classification.getValue().entrySet()) {
+ if (klass.equals(entry.getKey())) {
+ tp += entry.getValue();
+ } else {
+ fn += entry.getValue();
+ }
+ }
+ for (Map<String, Long> values : linearizedMatrix.values()) {
+ if (values.containsKey(klass)) {
+ fp += values.get(klass);
+ } else {
+ tn++;
}
}
}
- this.accuracy = cc / (cc + wc);
+ this.accuracy = (tp + tn) / (fp + fn + tp + tn);
}
return this.accuracy;
}
@@ -253,7 +259,7 @@ public class ConfusionMatrixGenerator {
*/
public double getPrecision() {
double tp = 0;
- double fp = -linearizedMatrix.size();
+ double fp = 0;
for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
String klass = classification.getKey();
for (Map.Entry<String, Long> entry : classification.getValue().entrySet()) {
@@ -268,8 +274,7 @@ public class ConfusionMatrixGenerator {
}
}
- return tp + fp > 0 ? tp / (tp + fp) : 0;
-
+ return tp > 0 ? tp / (tp + fp) : 0;
}
/**
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/d08f327a/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
----------------------------------------------------------------------
diff --git a/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java b/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
index e582b79..d1966ec 100644
--- a/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
+++ b/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
@@ -65,12 +65,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
assertTrue(avgClassificationTime >= 0d );
- assertTrue(confusionMatrix.getAccuracy() >= 0d);
- assertTrue(confusionMatrix.getAccuracy() <= 1d);
- assertTrue(confusionMatrix.getPrecision() >= 0d);
- assertTrue(confusionMatrix.getPrecision() <= 1d);
- assertTrue(confusionMatrix.getRecall() >= 0d);
- assertTrue(confusionMatrix.getRecall() <= 1d);
+ double accuracy = confusionMatrix.getAccuracy();
+ assertTrue(accuracy >= 0d);
+ assertTrue(accuracy <= 1d);
+ double precision = confusionMatrix.getPrecision();
+ assertTrue(precision >= 0d);
+ assertTrue(precision <= 1d);
+ double recall = confusionMatrix.getRecall();
+ assertTrue(recall >= 0d);
+ assertTrue(recall <= 1d);
} finally {
if (reader != null) {
reader.close();
@@ -90,12 +93,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
- assertTrue(confusionMatrix.getAccuracy() >= 0d);
- assertTrue(confusionMatrix.getAccuracy() <= 1d);
- assertTrue(confusionMatrix.getPrecision() >= 0d);
- assertTrue(confusionMatrix.getPrecision() <= 1d);
- assertTrue(confusionMatrix.getRecall() >= 0d);
- assertTrue(confusionMatrix.getRecall() <= 1d);
+ double accuracy = confusionMatrix.getAccuracy();
+ assertTrue(accuracy >= 0d);
+ assertTrue(accuracy <= 1d);
+ double precision = confusionMatrix.getPrecision();
+ assertTrue(precision >= 0d);
+ assertTrue(precision <= 1d);
+ double recall = confusionMatrix.getRecall();
+ assertTrue(recall >= 0d);
+ assertTrue(recall <= 1d);
} finally {
if (reader != null) {
reader.close();
@@ -115,12 +121,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
- assertTrue(confusionMatrix.getAccuracy() >= 0d);
- assertTrue(confusionMatrix.getAccuracy() <= 1d);
- assertTrue(confusionMatrix.getPrecision() >= 0d);
- assertTrue(confusionMatrix.getPrecision() <= 1d);
- assertTrue(confusionMatrix.getRecall() >= 0d);
- assertTrue(confusionMatrix.getRecall() <= 1d);
+ double accuracy = confusionMatrix.getAccuracy();
+ assertTrue(accuracy >= 0d);
+ assertTrue(accuracy <= 1d);
+ double precision = confusionMatrix.getPrecision();
+ assertTrue(precision >= 0d);
+ assertTrue(precision <= 1d);
+ double recall = confusionMatrix.getRecall();
+ assertTrue(recall >= 0d);
+ assertTrue(recall <= 1d);
} finally {
if (reader != null) {
reader.close();
@@ -140,12 +149,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
- assertTrue(confusionMatrix.getAccuracy() >= 0d);
- assertTrue(confusionMatrix.getAccuracy() <= 1d);
- assertTrue(confusionMatrix.getPrecision() >= 0d);
- assertTrue(confusionMatrix.getPrecision() <= 1d);
- assertTrue(confusionMatrix.getRecall() >= 0d);
- assertTrue(confusionMatrix.getRecall() <= 1d);
+ double accuracy = confusionMatrix.getAccuracy();
+ assertTrue(accuracy >= 0d);
+ assertTrue(accuracy <= 1d);
+ double precision = confusionMatrix.getPrecision();
+ assertTrue(precision >= 0d);
+ assertTrue(precision <= 1d);
+ double recall = confusionMatrix.getRecall();
+ assertTrue(recall >= 0d);
+ assertTrue(recall <= 1d);
} finally {
if (reader != null) {
reader.close();
@@ -165,12 +177,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
assertNotNull(confusionMatrix.getLinearizedMatrix());
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
- assertTrue(confusionMatrix.getAccuracy() >= 0d);
- assertTrue(confusionMatrix.getAccuracy() <= 1d);
- assertTrue(confusionMatrix.getPrecision() >= 0d);
- assertTrue(confusionMatrix.getPrecision() <= 1d);
- assertTrue(confusionMatrix.getRecall() >= 0d);
- assertTrue(confusionMatrix.getRecall() <= 1d);
+ double accuracy = confusionMatrix.getAccuracy();
+ assertTrue(accuracy >= 0d);
+ assertTrue(accuracy <= 1d);
+ double precision = confusionMatrix.getPrecision();
+ assertTrue(precision >= 0d);
+ assertTrue(precision <= 1d);
+ double recall = confusionMatrix.getRecall();
+ assertTrue(recall >= 0d);
+ assertTrue(recall <= 1d);
assertTrue(confusionMatrix.getPrecision("true") >= 0d);
assertTrue(confusionMatrix.getPrecision("true") <= 1d);
assertTrue(confusionMatrix.getPrecision("false") >= 0d);