You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by to...@apache.org on 2016/03/31 14:50:34 UTC

lucene-solr:master: LUCENE-7156 - fixed precision and accuracy calculations

Repository: lucene-solr
Updated Branches:
  refs/heads/master e1b45568b -> d08f327a7


LUCENE-7156 - fixed precision and accuracy calculations


Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/d08f327a
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/d08f327a
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/d08f327a

Branch: refs/heads/master
Commit: d08f327a7f7d8b25272fb3e32fd6cc44cec1c03a
Parents: e1b4556
Author: Tommaso Teofili <te...@adobe.com>
Authored: Thu Mar 31 14:45:11 2016 +0200
Committer: Tommaso Teofili <te...@adobe.com>
Committed: Thu Mar 31 14:45:31 2016 +0200

----------------------------------------------------------------------
 .../utils/ConfusionMatrixGenerator.java         | 41 ++++++-----
 .../utils/ConfusionMatrixGeneratorTest.java     | 75 ++++++++++++--------
 2 files changed, 68 insertions(+), 48 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/d08f327a/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java
----------------------------------------------------------------------
diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java
index 17f5b21..c9ecc4b 100644
--- a/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java
+++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java
@@ -106,11 +106,11 @@ public class ConfusionMatrixGenerator {
                     if (aLong != null) {
                       stringLongMap.put(classified, aLong + 1);
                     } else {
-                      stringLongMap.put(classified, 1l);
+                      stringLongMap.put(classified, 1L);
                     }
                   } else {
                     stringLongMap = new HashMap<>();
-                    stringLongMap.put(classified, 1l);
+                    stringLongMap.put(classified, 1L);
                     counts.put(correctAnswer, stringLongMap);
                   }
 
@@ -225,23 +225,29 @@ public class ConfusionMatrixGenerator {
      */
     public double getAccuracy() {
       if (this.accuracy == -1) {
-        double cc = 0d;
-        double wc = 0d;
-        for (Map.Entry<String, Map<String, Long>> entry : linearizedMatrix.entrySet()) {
-          String correctAnswer = entry.getKey();
-          for (Map.Entry<String, Long> classifiedAnswers : entry.getValue().entrySet()) {
-            Long value = classifiedAnswers.getValue();
-            if (value != null) {
-              if (correctAnswer.equals(classifiedAnswers.getKey())) {
-                cc += value;
-              } else {
-                wc += value;
-              }
+        double tp = 0d;
+        double tn = 0d;
+        double fp = 0d;
+        double fn = 0d;
+        for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
+          String klass = classification.getKey();
+          for (Map.Entry<String, Long> entry : classification.getValue().entrySet()) {
+            if (klass.equals(entry.getKey())) {
+              tp += entry.getValue();
+            } else {
+              fn += entry.getValue();
+            }
+          }
+          for (Map<String, Long> values : linearizedMatrix.values()) {
+            if (values.containsKey(klass)) {
+              fp += values.get(klass);
+            } else {
+              tn++;
             }
           }
 
         }
-        this.accuracy = cc / (cc + wc);
+        this.accuracy = (tp + tn) / (fp + fn + tp + tn);
       }
       return this.accuracy;
     }
@@ -253,7 +259,7 @@ public class ConfusionMatrixGenerator {
      */
     public double getPrecision() {
       double tp = 0;
-      double fp = -linearizedMatrix.size();
+      double fp = 0;
       for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
         String klass = classification.getKey();
         for (Map.Entry<String, Long> entry : classification.getValue().entrySet()) {
@@ -268,8 +274,7 @@ public class ConfusionMatrixGenerator {
         }
       }
 
-      return tp + fp > 0 ? tp / (tp + fp) : 0;
-
+      return tp > 0 ? tp / (tp + fp) : 0;
     }
 
     /**

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/d08f327a/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
----------------------------------------------------------------------
diff --git a/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java b/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
index e582b79..d1966ec 100644
--- a/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
+++ b/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
@@ -65,12 +65,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
       assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
       double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
       assertTrue(avgClassificationTime >= 0d );
-      assertTrue(confusionMatrix.getAccuracy() >= 0d);
-      assertTrue(confusionMatrix.getAccuracy() <= 1d);
-      assertTrue(confusionMatrix.getPrecision() >= 0d);
-      assertTrue(confusionMatrix.getPrecision() <= 1d);
-      assertTrue(confusionMatrix.getRecall() >= 0d);
-      assertTrue(confusionMatrix.getRecall() <= 1d);
+      double accuracy = confusionMatrix.getAccuracy();
+      assertTrue(accuracy >= 0d);
+      assertTrue(accuracy <= 1d);
+      double precision = confusionMatrix.getPrecision();
+      assertTrue(precision >= 0d);
+      assertTrue(precision <= 1d);
+      double recall = confusionMatrix.getRecall();
+      assertTrue(recall >= 0d);
+      assertTrue(recall <= 1d);
     } finally {
       if (reader != null) {
         reader.close();
@@ -90,12 +93,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
       assertNotNull(confusionMatrix.getLinearizedMatrix());
       assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
       assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
-      assertTrue(confusionMatrix.getAccuracy() >= 0d);
-      assertTrue(confusionMatrix.getAccuracy() <= 1d);
-      assertTrue(confusionMatrix.getPrecision() >= 0d);
-      assertTrue(confusionMatrix.getPrecision() <= 1d);
-      assertTrue(confusionMatrix.getRecall() >= 0d);
-      assertTrue(confusionMatrix.getRecall() <= 1d);
+      double accuracy = confusionMatrix.getAccuracy();
+      assertTrue(accuracy >= 0d);
+      assertTrue(accuracy <= 1d);
+      double precision = confusionMatrix.getPrecision();
+      assertTrue(precision >= 0d);
+      assertTrue(precision <= 1d);
+      double recall = confusionMatrix.getRecall();
+      assertTrue(recall >= 0d);
+      assertTrue(recall <= 1d);
     } finally {
       if (reader != null) {
         reader.close();
@@ -115,12 +121,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
       assertNotNull(confusionMatrix.getLinearizedMatrix());
       assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
       assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
-      assertTrue(confusionMatrix.getAccuracy() >= 0d);
-      assertTrue(confusionMatrix.getAccuracy() <= 1d);
-      assertTrue(confusionMatrix.getPrecision() >= 0d);
-      assertTrue(confusionMatrix.getPrecision() <= 1d);
-      assertTrue(confusionMatrix.getRecall() >= 0d);
-      assertTrue(confusionMatrix.getRecall() <= 1d);
+      double accuracy = confusionMatrix.getAccuracy();
+      assertTrue(accuracy >= 0d);
+      assertTrue(accuracy <= 1d);
+      double precision = confusionMatrix.getPrecision();
+      assertTrue(precision >= 0d);
+      assertTrue(precision <= 1d);
+      double recall = confusionMatrix.getRecall();
+      assertTrue(recall >= 0d);
+      assertTrue(recall <= 1d);
     } finally {
       if (reader != null) {
         reader.close();
@@ -140,12 +149,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
       assertNotNull(confusionMatrix.getLinearizedMatrix());
       assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
       assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
-      assertTrue(confusionMatrix.getAccuracy() >= 0d);
-      assertTrue(confusionMatrix.getAccuracy() <= 1d);
-      assertTrue(confusionMatrix.getPrecision() >= 0d);
-      assertTrue(confusionMatrix.getPrecision() <= 1d);
-      assertTrue(confusionMatrix.getRecall() >= 0d);
-      assertTrue(confusionMatrix.getRecall() <= 1d);
+      double accuracy = confusionMatrix.getAccuracy();
+      assertTrue(accuracy >= 0d);
+      assertTrue(accuracy <= 1d);
+      double precision = confusionMatrix.getPrecision();
+      assertTrue(precision >= 0d);
+      assertTrue(precision <= 1d);
+      double recall = confusionMatrix.getRecall();
+      assertTrue(recall >= 0d);
+      assertTrue(recall <= 1d);
     } finally {
       if (reader != null) {
         reader.close();
@@ -165,12 +177,15 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
       assertNotNull(confusionMatrix.getLinearizedMatrix());
       assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
       assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
-      assertTrue(confusionMatrix.getAccuracy() >= 0d);
-      assertTrue(confusionMatrix.getAccuracy() <= 1d);
-      assertTrue(confusionMatrix.getPrecision() >= 0d);
-      assertTrue(confusionMatrix.getPrecision() <= 1d);
-      assertTrue(confusionMatrix.getRecall() >= 0d);
-      assertTrue(confusionMatrix.getRecall() <= 1d);
+      double accuracy = confusionMatrix.getAccuracy();
+      assertTrue(accuracy >= 0d);
+      assertTrue(accuracy <= 1d);
+      double precision = confusionMatrix.getPrecision();
+      assertTrue(precision >= 0d);
+      assertTrue(precision <= 1d);
+      double recall = confusionMatrix.getRecall();
+      assertTrue(recall >= 0d);
+      assertTrue(recall <= 1d);
       assertTrue(confusionMatrix.getPrecision("true") >= 0d);
       assertTrue(confusionMatrix.getPrecision("true") <= 1d);
       assertTrue(confusionMatrix.getPrecision("false") >= 0d);