You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by to...@apache.org on 2017/05/18 12:36:41 UTC
lucene-solr:master: LUCENE-7838 - added knn classifier based on flt
Repository: lucene-solr
Updated Branches:
refs/heads/master afd70a48c -> bd9e32d35
LUCENE-7838 - added knn classifier based on flt
Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/bd9e32d3
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/bd9e32d3
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/bd9e32d3
Branch: refs/heads/master
Commit: bd9e32d358399af7c31e732314e1ef1dd89bcfa1
Parents: afd70a4
Author: Tommaso Teofili <to...@apache.org>
Authored: Thu May 18 14:35:53 2017 +0200
Committer: Tommaso Teofili <to...@apache.org>
Committed: Thu May 18 14:36:18 2017 +0200
----------------------------------------------------------------------
.../lucene/classification/classification.iml | 3 +-
lucene/classification/build.xml | 8 +-
.../classification/KNearestFuzzyClassifier.java | 225 +++++++++++++++++++
.../classification/utils/DatasetSplitter.java | 2 +-
.../KNearestFuzzyClassifierTest.java | 124 ++++++++++
.../utils/ConfusionMatrixGeneratorTest.java | 123 +++++-----
6 files changed, 417 insertions(+), 68 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/bd9e32d3/dev-tools/idea/lucene/classification/classification.iml
----------------------------------------------------------------------
diff --git a/dev-tools/idea/lucene/classification/classification.iml b/dev-tools/idea/lucene/classification/classification.iml
index 0f20274..44af1e4 100644
--- a/dev-tools/idea/lucene/classification/classification.iml
+++ b/dev-tools/idea/lucene/classification/classification.iml
@@ -16,8 +16,9 @@
<orderEntry type="module" scope="TEST" module-name="lucene-test-framework" />
<orderEntry type="module" module-name="lucene-core" />
<orderEntry type="module" module-name="queries" />
- <orderEntry type="module" scope="TEST" module-name="analysis-common" />
+ <orderEntry type="module" module-name="analysis-common" />
<orderEntry type="module" module-name="grouping" />
<orderEntry type="module" module-name="misc" />
+ <orderEntry type="module" module-name="sandbox" />
</component>
</module>
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/bd9e32d3/lucene/classification/build.xml
----------------------------------------------------------------------
diff --git a/lucene/classification/build.xml b/lucene/classification/build.xml
index 704cae8..b3f1bfd 100644
--- a/lucene/classification/build.xml
+++ b/lucene/classification/build.xml
@@ -28,6 +28,8 @@
<path refid="base.classpath"/>
<pathelement path="${queries.jar}"/>
<pathelement path="${grouping.jar}"/>
+ <pathelement path="${sandbox.jar}"/>
+ <pathelement path="${analyzers-common.jar}"/>
</path>
<path id="test.classpath">
@@ -36,16 +38,18 @@
<path refid="test.base.classpath"/>
</path>
- <target name="compile-core" depends="jar-grouping,jar-queries,jar-analyzers-common,common.compile-core" />
+ <target name="compile-core" depends="jar-sandbox,jar-grouping,jar-queries,jar-analyzers-common,common.compile-core" />
<target name="jar-core" depends="common.jar-core" />
- <target name="javadocs" depends="javadocs-grouping,compile-core,check-javadocs-uptodate"
+ <target name="javadocs" depends="javadocs-sandbox,javadocs-grouping,compile-core,check-javadocs-uptodate"
unless="javadocs-uptodate-${name}">
<invoke-module-javadoc>
<links>
<link href="../queries"/>
+ <link href="../analyzers/common"/>
<link href="../grouping"/>
+ <link href="../sandbox"/>
</links>
</invoke-module-javadoc>
</target>
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/bd9e32d3/lucene/classification/src/java/org/apache/lucene/classification/KNearestFuzzyClassifier.java
----------------------------------------------------------------------
diff --git a/lucene/classification/src/java/org/apache/lucene/classification/KNearestFuzzyClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/KNearestFuzzyClassifier.java
new file mode 100644
index 0000000..1cde468
--- /dev/null
+++ b/lucene/classification/src/java/org/apache/lucene/classification/KNearestFuzzyClassifier.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.classification;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexableField;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.sandbox.queries.FuzzyLikeThisQuery;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.WildcardQuery;
+import org.apache.lucene.search.similarities.BM25Similarity;
+import org.apache.lucene.search.similarities.ClassicSimilarity;
+import org.apache.lucene.search.similarities.Similarity;
+import org.apache.lucene.util.BytesRef;
+
+/**
+ * A k-Nearest Neighbor classifier based on {@link FuzzyLikeThisQuery}.
+ * <p>
+ * The text to classify is turned into a {@link FuzzyLikeThisQuery} over the configured text fields. The {@code k}
+ * best matching documents that have the class field (and match the optional filter {@link Query}) act as
+ * neighbors. Each neighbor votes for its stored class, and the vote is weighted by the neighbor's score relative
+ * to the best score.
+ *
+ * @lucene.experimental
+ */
+public class KNearestFuzzyClassifier implements Classifier<BytesRef> {
+
+  // values passed to FuzzyLikeThisQuery (maxNumTerms) and FuzzyLikeThisQuery#addTerms (minSimilarity, prefixLength)
+  // TODO: make these parameters configurable
+  private static final int FLT_MAX_NUM_TERMS = 300;
+  private static final float FLT_MIN_SIMILARITY = 1f;
+  private static final int FLT_PREFIX_LENGTH = 2;
+
+  /**
+   * the name of the fields used as the input text
+   */
+  protected final String[] textFieldNames;
+
+  /**
+   * the name of the field used as the output text
+   */
+  protected final String classFieldName;
+
+  /**
+   * an {@link IndexSearcher} used to perform queries
+   */
+  protected final IndexSearcher indexSearcher;
+
+  /**
+   * the no. of docs to compare in order to find the nearest neighbor to the input text
+   */
+  protected final int k;
+
+  /**
+   * a {@link Query} used to filter the documents that should be used from this classifier's underlying {@link LeafReader}
+   */
+  protected final Query query;
+  private final Analyzer analyzer;
+
+  /**
+   * Creates a {@link KNearestFuzzyClassifier}.
+   *
+   * @param indexReader    the reader on the index to be used for classification
+   * @param similarity     the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null}
+   *                       (defaults to {@link BM25Similarity})
+   * @param analyzer       an {@link Analyzer} used to analyze unseen text
+   * @param query          a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
+   *                       if all the indexed docs should be used
+   * @param k              the no. of docs to select in the FLT results to find the nearest neighbor
+   * @param classFieldName the name of the field used as the output for the classifier
+   * @param textFieldNames the name of the fields used as the inputs for the classifier; a boosting indication such as
+   *                       {@code title^10} is accepted, but only the field name is used (the boost is not applied)
+   */
+  public KNearestFuzzyClassifier(IndexReader indexReader, Similarity similarity, Analyzer analyzer, Query query, int k,
+                                 String classFieldName, String... textFieldNames) {
+    this.textFieldNames = textFieldNames;
+    this.classFieldName = classFieldName;
+    this.analyzer = analyzer;
+    this.indexSearcher = new IndexSearcher(indexReader);
+    if (similarity != null) {
+      this.indexSearcher.setSimilarity(similarity);
+    } else {
+      this.indexSearcher.setSimilarity(new BM25Similarity());
+    }
+    this.query = query;
+    this.k = k;
+  }
+
+
+  /**
+   * {@inheritDoc}
+   * <p>
+   * Returns {@code null} when no neighbor with a stored class value is found.
+   */
+  @Override
+  public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
+    TopDocs knnResults = knnSearch(text);
+    List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
+    ClassificationResult<BytesRef> assignedClass = null;
+    double maxscore = -Double.MAX_VALUE;
+    for (ClassificationResult<BytesRef> cl : assignedClasses) {
+      if (cl.getScore() > maxscore) {
+        assignedClass = cl;
+        maxscore = cl.getScore();
+      }
+    }
+    return assignedClass;
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
+    TopDocs knnResults = knnSearch(text);
+    List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
+    Collections.sort(assignedClasses);
+    return assignedClasses;
+  }
+
+  /**
+   * {@inheritDoc}
+   * <p>
+   * If fewer than {@code max} classes are found, all of them are returned.
+   */
+  @Override
+  public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
+    TopDocs knnResults = knnSearch(text);
+    List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
+    Collections.sort(assignedClasses);
+    // the neighbors may cover fewer than max classes: clamp to avoid IndexOutOfBoundsException
+    return assignedClasses.subList(0, Math.min(max, assignedClasses.size()));
+  }
+
+  /**
+   * Runs the fuzzy-like-this search for the given text.
+   * <p>
+   * The search is restricted to docs that have a term in the class field and that match the optional filter query.
+   *
+   * @param text the text to classify
+   * @return the top {@code k} hits
+   * @throws IOException if the search fails
+   */
+  private TopDocs knnSearch(String text) throws IOException {
+    BooleanQuery.Builder bq = new BooleanQuery.Builder();
+    FuzzyLikeThisQuery fuzzyLikeThisQuery = new FuzzyLikeThisQuery(FLT_MAX_NUM_TERMS, analyzer);
+    for (String textFieldName : textFieldNames) {
+      // names may carry a boost suffix (e.g. "title^10"): passing it verbatim would query a non-existent field,
+      // so strip it; FuzzyLikeThisQuery offers no per-field boost, hence the boost itself is not applied
+      int boostIndex = textFieldName.indexOf('^');
+      String fieldName = boostIndex >= 0 ? textFieldName.substring(0, boostIndex) : textFieldName;
+      fuzzyLikeThisQuery.addTerms(text, fieldName, FLT_MIN_SIMILARITY, FLT_PREFIX_LENGTH);
+    }
+    bq.add(fuzzyLikeThisQuery, BooleanClause.Occur.MUST);
+    // only docs having some term in the class field can vote for a class
+    Query classFieldQuery = new WildcardQuery(new Term(classFieldName, "*"));
+    bq.add(classFieldQuery, BooleanClause.Occur.MUST);
+    if (query != null) {
+      bq.add(query, BooleanClause.Occur.MUST);
+    }
+    return indexSearcher.search(bq.build(), k);
+  }
+
+  /**
+   * build a list of classification results from search results
+   *
+   * @param topDocs the search results as a {@link TopDocs} object
+   * @return a {@link List} of {@link ClassificationResult}, one for each existing class
+   * @throws IOException if it's not possible to get the stored value of class field
+   */
+  protected List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
+    Map<BytesRef, Integer> classCounts = new HashMap<>();
+    // per class, the sum of the scores of its docs, each one divided by the best score in topDocs
+    Map<BytesRef, Double> classBoosts = new HashMap<>();
+    float maxScore = topDocs.getMaxScore();
+    for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
+      IndexableField storableField = indexSearcher.doc(scoreDoc.doc).getField(classFieldName);
+      if (storableField != null) {
+        BytesRef cl = new BytesRef(storableField.stringValue());
+        classCounts.merge(cl, 1, Integer::sum);
+        double singleBoost = scoreDoc.score / maxScore;
+        classBoosts.merge(cl, singleBoost, Double::sum);
+      }
+    }
+    List<ClassificationResult<BytesRef>> returnList = new ArrayList<>(classCounts.size());
+    int sumdoc = 0;
+    for (Map.Entry<BytesRef, Integer> entry : classCounts.entrySet()) {
+      int count = entry.getValue();
+      // average relative score of the class' docs, in (0, 1] for positive scores
+      double normBoost = classBoosts.get(entry.getKey()) / count;
+      returnList.add(new ClassificationResult<>(entry.getKey().clone(), (count * normBoost) / (double) k));
+      sumdoc += count;
+    }
+
+    // correction: fewer than k neighbors had a class value, so rescale scores as if sumdoc were k
+    if (sumdoc < k) {
+      List<ClassificationResult<BytesRef>> correctedList = new ArrayList<>(returnList.size());
+      for (ClassificationResult<BytesRef> cr : returnList) {
+        correctedList.add(new ClassificationResult<>(cr.getAssignedClass(), cr.getScore() * k / (double) sumdoc));
+      }
+      returnList = correctedList;
+    }
+    return returnList;
+  }
+
+  @Override
+  public String toString() {
+    return "KNearestFuzzyClassifier{" +
+        "textFieldNames=" + Arrays.toString(textFieldNames) +
+        ", classFieldName='" + classFieldName + '\'' +
+        ", k=" + k +
+        ", query=" + query +
+        ", similarity=" + indexSearcher.getSimilarity(true) +
+        '}';
+  }
+}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/bd9e32d3/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java
----------------------------------------------------------------------
diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java
index 7ab674e..913fb7f 100644
--- a/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java
+++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java
@@ -121,7 +121,7 @@ public class DatasetSplitter {
int b = 0;
// iterate over existing documents
- for (GroupDocs group : topGroups.groups) {
+ for (GroupDocs<Object> group : topGroups.groups) {
int totalHits = group.totalHits;
double testSize = totalHits * testRatio;
int tc = 0;
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/bd9e32d3/lucene/classification/src/test/org/apache/lucene/classification/KNearestFuzzyClassifierTest.java
----------------------------------------------------------------------
diff --git a/lucene/classification/src/test/org/apache/lucene/classification/KNearestFuzzyClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/KNearestFuzzyClassifierTest.java
new file mode 100644
index 0000000..6e4c404
--- /dev/null
+++ b/lucene/classification/src/test/org/apache/lucene/classification/KNearestFuzzyClassifierTest.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.classification;
+
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.MockAnalyzer;
+import org.apache.lucene.analysis.Tokenizer;
+import org.apache.lucene.analysis.core.KeywordTokenizer;
+import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
+import org.apache.lucene.analysis.reverse.ReverseStringFilter;
+import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.MultiFields;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.Terms;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.util.BytesRef;
+import org.junit.Test;
+
+/**
+ * Testcase for {@link KNearestFuzzyClassifier}.
+ * <p>
+ * Covers classification of the sample index, with and without a filter query, and sanity and timing checks of
+ * the evaluation measures on a random index.
+ */
+public class KNearestFuzzyClassifierTest extends ClassificationTestBase<BytesRef> {
+
+  @Test
+  public void testBasicUsage() throws Exception {
+    LeafReader reader = null;
+    try {
+      MockAnalyzer mockAnalyzer = new MockAnalyzer(random());
+      reader = getSampleIndex(mockAnalyzer);
+      Classifier<BytesRef> knnFuzzy =
+          new KNearestFuzzyClassifier(reader, null, mockAnalyzer, null, 3, categoryFieldName, textFieldName);
+      checkCorrectClassification(knnFuzzy, TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
+      checkCorrectClassification(knnFuzzy, POLITICS_INPUT, POLITICS_RESULT);
+    } finally {
+      if (reader != null) {
+        reader.close();
+      }
+    }
+  }
+
+  @Test
+  public void testBasicUsageWithQuery() throws Exception {
+    LeafReader reader = null;
+    try {
+      MockAnalyzer mockAnalyzer = new MockAnalyzer(random());
+      reader = getSampleIndex(mockAnalyzer);
+      // restrict the neighbors to docs whose text contains "not"
+      TermQuery filter = new TermQuery(new Term(textFieldName, "not"));
+      Classifier<BytesRef> knnFuzzy =
+          new KNearestFuzzyClassifier(reader, null, mockAnalyzer, filter, 3, categoryFieldName, textFieldName);
+      checkCorrectClassification(knnFuzzy, TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
+    } finally {
+      if (reader != null) {
+        reader.close();
+      }
+    }
+  }
+
+  @Test
+  public void testPerformance() throws Exception {
+    MockAnalyzer mockAnalyzer = new MockAnalyzer(random());
+    LeafReader reader = getRandomIndex(mockAnalyzer, 100);
+    try {
+      long start = System.currentTimeMillis();
+      Classifier<BytesRef> knnFuzzy =
+          new KNearestFuzzyClassifier(reader, null, mockAnalyzer, null, 3, categoryFieldName, textFieldName);
+      long trainTime = System.currentTimeMillis() - start;
+      assertTrue("training took more than 10s: " + trainTime / 1000 + "s", trainTime < 10000);
+
+      start = System.currentTimeMillis();
+      ConfusionMatrixGenerator.ConfusionMatrix matrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
+          knnFuzzy, categoryFieldName, textFieldName, -1);
+      assertNotNull(matrix);
+      long evaluationTime = System.currentTimeMillis() - start;
+      assertTrue("evaluation took more than 2m: " + evaluationTime / 1000 + "s", evaluationTime < 120000);
+      assertTrue(5000 > matrix.getAvgClassificationTime());
+      assertInUnitRange(matrix.getAccuracy());
+      assertInUnitRange(matrix.getRecall());
+      assertInUnitRange(matrix.getPrecision());
+
+      // per-class measures, one check for each term of the category field
+      Terms categories = MultiFields.getTerms(reader, categoryFieldName);
+      TermsEnum categoriesEnum = categories.iterator();
+      for (BytesRef category = categoriesEnum.next(); category != null; category = categoriesEnum.next()) {
+        String className = category.utf8ToString();
+        assertInUnitRange(matrix.getRecall(className));
+        assertInUnitRange(matrix.getPrecision(className));
+        assertInUnitRange(matrix.getF1Measure(className));
+      }
+    } finally {
+      reader.close();
+    }
+  }
+
+  /** Asserts that {@code measure} lies in the closed interval [0, 1]. */
+  private static void assertInUnitRange(double measure) {
+    assertTrue(measure >= 0d);
+    assertTrue(measure <= 1d);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/bd9e32d3/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
----------------------------------------------------------------------
diff --git a/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java b/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
index 63cce2a..edb76b5 100644
--- a/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
+++ b/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
@@ -21,11 +21,13 @@ import java.io.IOException;
import java.util.List;
import org.apache.lucene.analysis.MockAnalyzer;
+import org.apache.lucene.classification.BM25NBClassifier;
import org.apache.lucene.classification.BooleanPerceptronClassifier;
import org.apache.lucene.classification.CachingNaiveBayesClassifier;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.ClassificationTestBase;
import org.apache.lucene.classification.Classifier;
+import org.apache.lucene.classification.KNearestFuzzyClassifier;
import org.apache.lucene.classification.KNearestNeighborClassifier;
import org.apache.lucene.classification.SimpleNaiveBayesClassifier;
import org.apache.lucene.index.LeafReader;
@@ -94,22 +96,43 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
Classifier<BytesRef> classifier = new SimpleNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
classifier, categoryFieldName, textFieldName, -1);
- assertNotNull(confusionMatrix);
- assertNotNull(confusionMatrix.getLinearizedMatrix());
- assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
- assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
- double accuracy = confusionMatrix.getAccuracy();
- assertTrue(accuracy >= 0d);
- assertTrue(accuracy <= 1d);
- double precision = confusionMatrix.getPrecision();
- assertTrue(precision >= 0d);
- assertTrue(precision <= 1d);
- double recall = confusionMatrix.getRecall();
- assertTrue(recall >= 0d);
- assertTrue(recall <= 1d);
- double f1Measure = confusionMatrix.getF1Measure();
- assertTrue(f1Measure >= 0d);
- assertTrue(f1Measure <= 1d);
+ checkCM(confusionMatrix);
+ } finally {
+ if (reader != null) {
+ reader.close();
+ }
+ }
+ }
+
+  /**
+   * Checks the invariants shared by the confusion matrices built in this test.
+   * <p>
+   * The matrix must be non-null, expose a linearized matrix and have been evaluated over exactly 7 docs. Its
+   * average classification time must be non-negative, and its accuracy, precision, recall and F1 measure must all
+   * lie in [0, 1].
+   */
+  private void checkCM(ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix) {
+    assertNotNull(confusionMatrix);
+    assertNotNull(confusionMatrix.getLinearizedMatrix());
+    assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
+    assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
+    // each measure is read and checked before the next one is read, as in the original sequence
+    assertWithinUnitInterval(confusionMatrix.getAccuracy());
+    assertWithinUnitInterval(confusionMatrix.getPrecision());
+    assertWithinUnitInterval(confusionMatrix.getRecall());
+    assertWithinUnitInterval(confusionMatrix.getF1Measure());
+  }
+
+  /** Asserts that {@code measure} lies in the closed interval [0, 1]. */
+  private static void assertWithinUnitInterval(double measure) {
+    assertTrue(measure >= 0d);
+    assertTrue(measure <= 1d);
+  }
+
+ @Test
+ public void testGetConfusionMatrixWithBM25NB() throws Exception {
+ LeafReader reader = null;
+ try {
+ MockAnalyzer analyzer = new MockAnalyzer(random());
+ reader = getSampleIndex(analyzer);
+ Classifier<BytesRef> classifier = new BM25NBClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
+ ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
+ classifier, categoryFieldName, textFieldName, -1);
+ checkCM(confusionMatrix);
} finally {
if (reader != null) {
reader.close();
@@ -126,22 +149,7 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
Classifier<BytesRef> classifier = new CachingNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
classifier, categoryFieldName, textFieldName, -1);
- assertNotNull(confusionMatrix);
- assertNotNull(confusionMatrix.getLinearizedMatrix());
- assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
- assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
- double accuracy = confusionMatrix.getAccuracy();
- assertTrue(accuracy >= 0d);
- assertTrue(accuracy <= 1d);
- double precision = confusionMatrix.getPrecision();
- assertTrue(precision >= 0d);
- assertTrue(precision <= 1d);
- double recall = confusionMatrix.getRecall();
- assertTrue(recall >= 0d);
- assertTrue(recall <= 1d);
- double f1Measure = confusionMatrix.getF1Measure();
- assertTrue(f1Measure >= 0d);
- assertTrue(f1Measure <= 1d);
+ checkCM(confusionMatrix);
} finally {
if (reader != null) {
reader.close();
@@ -158,22 +166,24 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
Classifier<BytesRef> classifier = new KNearestNeighborClassifier(reader, null, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
classifier, categoryFieldName, textFieldName, -1);
- assertNotNull(confusionMatrix);
- assertNotNull(confusionMatrix.getLinearizedMatrix());
- assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
- assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
- double accuracy = confusionMatrix.getAccuracy();
- assertTrue(accuracy >= 0d);
- assertTrue(accuracy <= 1d);
- double precision = confusionMatrix.getPrecision();
- assertTrue(precision >= 0d);
- assertTrue(precision <= 1d);
- double recall = confusionMatrix.getRecall();
- assertTrue(recall >= 0d);
- assertTrue(recall <= 1d);
- double f1Measure = confusionMatrix.getF1Measure();
- assertTrue(f1Measure >= 0d);
- assertTrue(f1Measure <= 1d);
+ checkCM(confusionMatrix);
+ } finally {
+ if (reader != null) {
+ reader.close();
+ }
+ }
+ }
+
+ @Test
+ public void testGetConfusionMatrixWithFLTKNN() throws Exception {
+ LeafReader reader = null;
+ try {
+ MockAnalyzer analyzer = new MockAnalyzer(random());
+ reader = getSampleIndex(analyzer);
+ Classifier<BytesRef> classifier = new KNearestFuzzyClassifier(reader, null, analyzer, null, 1, categoryFieldName, textFieldName);
+ ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
+ classifier, categoryFieldName, textFieldName, -1);
+ checkCM(confusionMatrix);
} finally {
if (reader != null) {
reader.close();
@@ -190,22 +200,7 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
Classifier<Boolean> classifier = new BooleanPerceptronClassifier(reader, analyzer, null, 1, null, booleanFieldName, textFieldName);
ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
classifier, booleanFieldName, textFieldName, -1);
- assertNotNull(confusionMatrix);
- assertNotNull(confusionMatrix.getLinearizedMatrix());
- assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
- assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
- double accuracy = confusionMatrix.getAccuracy();
- assertTrue(accuracy >= 0d);
- assertTrue(accuracy <= 1d);
- double precision = confusionMatrix.getPrecision();
- assertTrue(precision >= 0d);
- assertTrue(precision <= 1d);
- double recall = confusionMatrix.getRecall();
- assertTrue(recall >= 0d);
- assertTrue(recall <= 1d);
- double f1Measure = confusionMatrix.getF1Measure();
- assertTrue(f1Measure >= 0d);
- assertTrue(f1Measure <= 1d);
+ checkCM(confusionMatrix);
assertTrue(confusionMatrix.getPrecision("true") >= 0d);
assertTrue(confusionMatrix.getPrecision("true") <= 1d);
assertTrue(confusionMatrix.getPrecision("false") >= 0d);