You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by to...@apache.org on 2015/10/23 17:49:04 UTC

svn commit: r1710249 - in /lucene/dev/trunk/lucene/classification/src: java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java

Author: tommaso
Date: Fri Oct 23 15:49:04 2015
New Revision: 1710249

URL: http://svn.apache.org/viewvc?rev=1710249&view=rev
Log:
LUCENE-6854 - added precision, recall, f1 measure metrics to ConfusionMatrix

Modified:
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java
    lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java?rev=1710249&r1=1710248&r2=1710249&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java Fri Oct 23 15:49:04 2015
@@ -32,12 +32,10 @@ import org.apache.lucene.classification.
 import org.apache.lucene.classification.Classifier;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.StoredDocument;
-import org.apache.lucene.index.Term;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TermRangeQuery;
 import org.apache.lucene.search.TopDocs;
-import org.apache.lucene.search.WildcardQuery;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.NamedThreadFactory;
 
@@ -152,6 +150,7 @@ public class ConfusionMatrixGenerator {
 
     /**
      * get the linearized confusion matrix as a {@link Map}
+     *
      * @return a {@link Map} whose keys are the correct classification answers and whose values are the actual answers'
      * counts
      */
@@ -160,6 +159,61 @@ public class ConfusionMatrixGenerator {
     }
 
     /**
+     * calculate precision on the given class as tp / (tp + fp); 0 if the class was never assigned
+     *
+     * @param klass the class to calculate the precision for
+     * @return the precision for the given class
+     */
+    public double getPrecision(String klass) {
+      // tp: docs whose correct class is klass and that were also assigned klass (diagonal cell)
+      Map<String, Long> classifications = linearizedMatrix.get(klass);
+      Long diagonal = classifications == null ? null : classifications.get(klass);
+      double tp = diagonal == null ? 0 : diagonal;
+      // assigned: every doc assigned klass whatever its correct class, i.e. already tp + fp
+      double assigned = 0;
+      for (Map<String, Long> values : linearizedMatrix.values()) {
+        Long count = values.get(klass);
+        if (count != null) {
+          assigned += count;
+        }
+      }
+      // no doc was assigned klass: report 0 rather than the NaN produced by 0 / 0
+      return assigned > 0 ? tp / assigned : 0;
+    }
+
+    /**
+     * calculate recall on the given class as tp / (tp + fn); 0 if the class never was a correct answer
+     *
+     * @param klass the class to calculate the recall for
+     * @return the recall for the given class
+     */
+    public double getRecall(String klass) {
+      Map<String, Long> classifications = linearizedMatrix.get(klass);
+      if (classifications == null) {
+        return 0; // klass never occurred as a correct answer, avoid a NullPointerException
+      }
+      Long diagonal = classifications.get(klass);
+      double tp = diagonal == null ? 0 : diagonal;
+      double total = 0; // tp + fn: every doc whose correct class is klass
+      for (long count : classifications.values()) {
+        total += count;
+      }
+      return total > 0 ? tp / total : 0;
+    }
+
+    /**
+     * get the F-1 measure (harmonic mean of precision and recall) of the given class
+     *
+     * @param klass the class to calculate the F-1 measure for
+     * @return the F-1 measure for the given class, 0 when precision and recall are both 0
+     */
+    public double getF1Measure(String klass) {
+      double recall = getRecall(klass);
+      double precision = getPrecision(klass);
+      return precision + recall > 0 ? 2 * precision * recall / (precision + recall) : 0; // no 0/0 NaN
+    }
+
+    /**
      * Calculate accuracy on this confusion matrix using the formula:
      * {@literal accuracy = correctly-classified / (correctly-classified + wrongly-classified)}
      *
@@ -199,6 +253,7 @@ public class ConfusionMatrixGenerator {
 
     /**
      * get the average classification time in milliseconds
+     *
      * @return the avg classification time
      */
     public double getAvgClassificationTime() {
@@ -207,6 +262,7 @@ public class ConfusionMatrixGenerator {
 
     /**
      * get the no. of documents evaluated while generating this confusion matrix
+     *
      * @return the no. of documents evaluated
      */
     public int getNumberOfEvaluatedDocs() {

Modified: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java?rev=1710249&r1=1710248&r2=1710249&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java (original)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java Fri Oct 23 15:49:04 2015
@@ -145,6 +145,12 @@ public class ConfusionMatrixGeneratorTes
       assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
       assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
       assertTrue(confusionMatrix.getAccuracy() > 0d);
+      assertTrue(confusionMatrix.getPrecision("true") > 0d);
+      assertTrue(confusionMatrix.getPrecision("false") > 0d);
+      assertTrue(confusionMatrix.getRecall("true") > 0d);
+      assertTrue(confusionMatrix.getRecall("false") > 0d);
+      assertTrue(confusionMatrix.getF1Measure("true") > 0d);
+      assertTrue(confusionMatrix.getF1Measure("false") > 0d);
     } finally {
       if (reader != null) {
         reader.close();