You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by to...@apache.org on 2016/04/08 11:07:31 UTC

[2/2] lucene-solr:master: LUCENE-7193 - add generic f1-measure metric to confusion matrix

LUCENE-7193 - add generic f1-measure metric to confusion matrix


Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/2507015f
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/2507015f
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/2507015f

Branch: refs/heads/master
Commit: 2507015f4c9a024c65de12a0de263ebe51d3de2a
Parents: b02b026
Author: Tommaso Teofili <to...@apache.org>
Authored: Fri Apr 8 11:02:48 2016 +0200
Committer: Tommaso Teofili <to...@apache.org>
Committed: Fri Apr 8 11:06:43 2016 +0200

----------------------------------------------------------------------
 .../utils/ConfusionMatrixGenerator.java         | 39 +++++++++++++++-----
 .../utils/ConfusionMatrixGeneratorTest.java     | 30 ++++++++++++---
 2 files changed, 54 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/2507015f/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java
----------------------------------------------------------------------
diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java
index c9ecc4b..3dd8ba8 100644
--- a/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java
+++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java
@@ -52,16 +52,17 @@ public class ConfusionMatrixGenerator {
    * get the {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix} of a given {@link Classifier},
    * generated on the given {@link LeafReader}, class and text fields.
    *
-   * @param reader         the {@link LeafReader} containing the index used for creating the {@link Classifier}
-   * @param classifier     the {@link Classifier} whose confusion matrix has to be generated
-   * @param classFieldName the name of the Lucene field used as the classifier's output
-   * @param textFieldName  the nome the Lucene field used as the classifier's input
-   * @param <T>            the return type of the {@link ClassificationResult} returned by the given {@link Classifier}
+   * @param reader              the {@link LeafReader} containing the index used for creating the {@link Classifier}
+   * @param classifier          the {@link Classifier} whose confusion matrix has to be generated
+   * @param classFieldName      the name of the Lucene field used as the classifier's output
+   * @param textFieldName       the name of the Lucene field used as the classifier's input
+   * @param timeoutMilliseconds timeout to wait before stopping creating the confusion matrix
+   * @param <T>                 the return type of the {@link ClassificationResult} returned by the given {@link Classifier}
    * @return a {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix}
    * @throws IOException if problems occurr while reading the index or using the classifier
    */
   public static <T> ConfusionMatrix getConfusionMatrix(LeafReader reader, Classifier<T> classifier, String classFieldName,
-                                                       String textFieldName) throws IOException {
+                                                       String textFieldName, long timeoutMilliseconds) throws IOException {
 
     ExecutorService executorService = Executors.newFixedThreadPool(1, new NamedThreadFactory("confusion-matrix-gen-"));
 
@@ -72,7 +73,13 @@ public class ConfusionMatrixGenerator {
       TopDocs topDocs = indexSearcher.search(new TermRangeQuery(classFieldName, null, null, true, true), Integer.MAX_VALUE);
       double time = 0d;
 
+      int counter = 0;
       for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
+
+        if (timeoutMilliseconds > 0 && time >= timeoutMilliseconds) {
+          break;
+        }
+
         Document doc = reader.document(scoreDoc.doc);
         String[] correctAnswers = doc.getValues(classFieldName);
 
@@ -91,6 +98,7 @@ public class ConfusionMatrixGenerator {
               if (result != null) {
                 T assignedClass = result.getAssignedClass();
                 if (assignedClass != null) {
+                  counter++;
                   String classified = assignedClass instanceof BytesRef ? ((BytesRef) assignedClass).utf8ToString() : assignedClass.toString();
 
                   String correctAnswer;
@@ -117,7 +125,7 @@ public class ConfusionMatrixGenerator {
                 }
               }
             } catch (TimeoutException timeoutException) {
-              // add timeout
+              // add classification timeout
               time += 5000;
             } catch (ExecutionException | InterruptedException executionException) {
               throw new RuntimeException(executionException);
@@ -126,7 +134,7 @@ public class ConfusionMatrixGenerator {
           }
         }
       }
-      return new ConfusionMatrix(counts, time / topDocs.totalHits, topDocs.totalHits);
+      return new ConfusionMatrix(counts, time / counter, counter);
     } finally {
       executorService.shutdown();
     }
@@ -167,7 +175,7 @@ public class ConfusionMatrixGenerator {
     public double getPrecision(String klass) {
       Map<String, Long> classifications = linearizedMatrix.get(klass);
       double tp = 0;
-      double fp = -1;
+      double fp = 0;
       if (classifications != null) {
         for (Map.Entry<String, Long> entry : classifications.entrySet()) {
           if (klass.equals(entry.getKey())) {
@@ -180,7 +188,7 @@ public class ConfusionMatrixGenerator {
           }
         }
       }
-      return tp + fp > 0 ? tp / (tp + fp) : 0;
+      return tp > 0 ? tp / (tp + fp) : 0;
     }
 
     /**
@@ -218,6 +226,17 @@ public class ConfusionMatrixGenerator {
     }
 
     /**
+     * get the F-1 measure on this confusion matrix
+     *
+     * @return the F-1 measure
+     */
+    public double getF1Measure() {
+      double recall = getRecall();
+      double precision = getPrecision();
+      return precision > 0 && recall > 0 ? 2 * precision * recall / (precision + recall) : 0;
+    }
+
+    /**
      * Calculate accuracy on this confusion matrix using the formula:
      * {@literal accuracy = correctly-classified / (correctly-classified + wrongly-classified)}
      *

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/2507015f/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
----------------------------------------------------------------------
diff --git a/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java b/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
index d1966ec..63cce2a 100644
--- a/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
+++ b/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
@@ -59,7 +59,8 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
           return null;
         }
       };
-      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName);
+      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
+          classifier, categoryFieldName, textFieldName, -1);
       assertNotNull(confusionMatrix);
       assertNotNull(confusionMatrix.getLinearizedMatrix());
       assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
@@ -74,6 +75,9 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
       double recall = confusionMatrix.getRecall();
       assertTrue(recall >= 0d);
       assertTrue(recall <= 1d);
+      double f1Measure = confusionMatrix.getF1Measure();
+      assertTrue(f1Measure >= 0d);
+      assertTrue(f1Measure <= 1d);
     } finally {
       if (reader != null) {
         reader.close();
@@ -88,7 +92,8 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
       MockAnalyzer analyzer = new MockAnalyzer(random());
       reader = getSampleIndex(analyzer);
       Classifier<BytesRef> classifier = new SimpleNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
-      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName);
+      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
+          classifier, categoryFieldName, textFieldName, -1);
       assertNotNull(confusionMatrix);
       assertNotNull(confusionMatrix.getLinearizedMatrix());
       assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
@@ -102,6 +107,9 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
       double recall = confusionMatrix.getRecall();
       assertTrue(recall >= 0d);
       assertTrue(recall <= 1d);
+      double f1Measure = confusionMatrix.getF1Measure();
+      assertTrue(f1Measure >= 0d);
+      assertTrue(f1Measure <= 1d);
     } finally {
       if (reader != null) {
         reader.close();
@@ -116,7 +124,8 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
       MockAnalyzer analyzer = new MockAnalyzer(random());
       reader = getSampleIndex(analyzer);
       Classifier<BytesRef> classifier = new CachingNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
-      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName);
+      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
+          classifier, categoryFieldName, textFieldName, -1);
       assertNotNull(confusionMatrix);
       assertNotNull(confusionMatrix.getLinearizedMatrix());
       assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
@@ -130,6 +139,9 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
       double recall = confusionMatrix.getRecall();
       assertTrue(recall >= 0d);
       assertTrue(recall <= 1d);
+      double f1Measure = confusionMatrix.getF1Measure();
+      assertTrue(f1Measure >= 0d);
+      assertTrue(f1Measure <= 1d);
     } finally {
       if (reader != null) {
         reader.close();
@@ -144,7 +156,8 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
       MockAnalyzer analyzer = new MockAnalyzer(random());
       reader = getSampleIndex(analyzer);
       Classifier<BytesRef> classifier = new KNearestNeighborClassifier(reader, null, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName);
-      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName);
+      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
+          classifier, categoryFieldName, textFieldName, -1);
       assertNotNull(confusionMatrix);
       assertNotNull(confusionMatrix.getLinearizedMatrix());
       assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
@@ -158,6 +171,9 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
       double recall = confusionMatrix.getRecall();
       assertTrue(recall >= 0d);
       assertTrue(recall <= 1d);
+      double f1Measure = confusionMatrix.getF1Measure();
+      assertTrue(f1Measure >= 0d);
+      assertTrue(f1Measure <= 1d);
     } finally {
       if (reader != null) {
         reader.close();
@@ -172,7 +188,8 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
       MockAnalyzer analyzer = new MockAnalyzer(random());
       reader = getSampleIndex(analyzer);
       Classifier<Boolean> classifier = new BooleanPerceptronClassifier(reader, analyzer, null, 1, null, booleanFieldName, textFieldName);
-      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, booleanFieldName, textFieldName);
+      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
+          classifier, booleanFieldName, textFieldName, -1);
       assertNotNull(confusionMatrix);
       assertNotNull(confusionMatrix.getLinearizedMatrix());
       assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
@@ -186,6 +203,9 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
       double recall = confusionMatrix.getRecall();
       assertTrue(recall >= 0d);
       assertTrue(recall <= 1d);
+      double f1Measure = confusionMatrix.getF1Measure();
+      assertTrue(f1Measure >= 0d);
+      assertTrue(f1Measure <= 1d);
       assertTrue(confusionMatrix.getPrecision("true") >= 0d);
       assertTrue(confusionMatrix.getPrecision("true") <= 1d);
       assertTrue(confusionMatrix.getPrecision("false") >= 0d);