You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by to...@apache.org on 2015/10/23 17:49:04 UTC
svn commit: r1710249 - in /lucene/dev/trunk/lucene/classification/src:
java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java
test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
Author: tommaso
Date: Fri Oct 23 15:49:04 2015
New Revision: 1710249
URL: http://svn.apache.org/viewvc?rev=1710249&view=rev
Log:
LUCENE-6854 - added precision, recall, f1 measure metrics to ConfusionMatrix
Modified:
lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java
lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java?rev=1710249&r1=1710248&r2=1710249&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java Fri Oct 23 15:49:04 2015
@@ -32,12 +32,10 @@ import org.apache.lucene.classification.
import org.apache.lucene.classification.Classifier;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.StoredDocument;
-import org.apache.lucene.index.Term;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TermRangeQuery;
import org.apache.lucene.search.TopDocs;
-import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.NamedThreadFactory;
@@ -152,6 +150,7 @@ public class ConfusionMatrixGenerator {
/**
* get the linearized confusion matrix as a {@link Map}
+ *
* @return a {@link Map} whose keys are the correct classification answers and whose values are the actual answers'
* counts
*/
@@ -160,6 +159,61 @@ public class ConfusionMatrixGenerator {
}
/**
+ * calculate precision on the given class
+ *
+ * @param klass the class to calculate the precision for
+ * @return the precision for the given class
+ */
+ public double getPrecision(String klass) {
+ Map<String, Long> classifications = linearizedMatrix.get(klass);
+ double tp = 0;
+ double fp = 0;
+ for (Map.Entry<String, Long> entry : classifications.entrySet()) {
+ if (klass.equals(entry.getKey())) {
+ tp += entry.getValue();
+ }
+ }
+ for (Map<String, Long> values : linearizedMatrix.values()) {
+ if (values.containsKey(klass)) {
+ fp += values.get(klass);
+ }
+ }
+ return tp / (tp + fp);
+ }
+
+ /**
+ * calculate recall on the given class
+ *
+ * @param klass the class to calculate the recall for
+ * @return the recall for the given class
+ */
+ public double getRecall(String klass) {
+ Map<String, Long> classifications = linearizedMatrix.get(klass);
+ double tp = 0;
+ double fn = 0;
+ for (Map.Entry<String, Long> entry : classifications.entrySet()) {
+ if (klass.equals(entry.getKey())) {
+ tp += entry.getValue();
+ } else {
+ fn += entry.getValue();
+ }
+ }
+ return tp / (tp + fn);
+ }
+
+ /**
+ * get the F-1 measure of the given class
+ *
+ * @param klass the class to calculate the F-1 measure for
+ * @return the F-1 measure for the given class
+ */
+ public double getF1Measure(String klass) {
+ double recall = getRecall(klass);
+ double precision = getPrecision(klass);
+ return 2 * precision * recall / (precision + recall);
+ }
+
+ /**
* Calculate accuracy on this confusion matrix using the formula:
* {@literal accuracy = correctly-classified / (correctly-classified + wrongly-classified)}
*
@@ -199,6 +253,7 @@ public class ConfusionMatrixGenerator {
/**
* get the average classification time in milliseconds
+ *
* @return the avg classification time
*/
public double getAvgClassificationTime() {
@@ -207,6 +262,7 @@ public class ConfusionMatrixGenerator {
/**
* get the no. of documents evaluated while generating this confusion matrix
+ *
* @return the no. of documents evaluated
*/
public int getNumberOfEvaluatedDocs() {
Modified: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java?rev=1710249&r1=1710248&r2=1710249&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java (original)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java Fri Oct 23 15:49:04 2015
@@ -145,6 +145,12 @@ public class ConfusionMatrixGeneratorTes
assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
assertTrue(confusionMatrix.getAccuracy() > 0d);
+ assertTrue(confusionMatrix.getPrecision("true") > 0d);
+ assertTrue(confusionMatrix.getPrecision("false") > 0d);
+ assertTrue(confusionMatrix.getRecall("true") > 0d);
+ assertTrue(confusionMatrix.getRecall("false") > 0d);
+ assertTrue(confusionMatrix.getF1Measure("true") > 0d);
+ assertTrue(confusionMatrix.getF1Measure("false") > 0d);
} finally {
if (reader != null) {
reader.close();