You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@lucene.apache.org by to...@apache.org on 2013/10/17 10:14:03 UTC
svn commit: r1533003 - in
/lucene/dev/branches/branch_4x/lucene/classification/src:
java/org/apache/lucene/classification/ test/org/apache/lucene/classification/
Author: tommaso
Date: Thu Oct 17 08:14:03 2013
New Revision: 1533003
URL: http://svn.apache.org/r1533003
Log:
LUCENE-5290 - backport LUCENE-5284 and performance tests from trunk
Modified:
lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java
Modified: lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java?rev=1533003&r1=1533002&r2=1533003&view=diff
==============================================================================
--- lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java (original)
+++ lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java Thu Oct 17 08:14:03 2013
@@ -18,6 +18,7 @@ package org.apache.lucene.classification
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.AtomicReader;
+import org.apache.lucene.search.Query;
import java.io.IOException;
@@ -47,4 +48,16 @@ public interface Classifier<T> {
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer)
throws IOException;
+ /**
+ * Train the classifier using the underlying Lucene index
+ * @param atomicReader the reader to use to access the Lucene index
+ * @param textFieldName the name of the field used to compare documents
+ * @param classFieldName the name of the field containing the class assigned to documents
+ * @param analyzer the analyzer used to tokenize / filter the unseen text
+ * @param query the query used to select which documents are used for training
+ * @throws IOException If there is a low-level I/O error.
+ */
+ public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
+ throws IOException;
+
}
Modified: lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java?rev=1533003&r1=1533002&r2=1533003&view=diff
==============================================================================
--- lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java (original)
+++ lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java Thu Oct 17 08:14:03 2013
@@ -19,6 +19,8 @@ package org.apache.lucene.classification
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.AtomicReader;
import org.apache.lucene.queries.mlt.MoreLikeThis;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
@@ -43,6 +45,7 @@ public class KNearestNeighborClassifier
private String classFieldName;
private IndexSearcher indexSearcher;
private int k;
+ private Query query;
/**
* Create a {@link Classifier} using kNN algorithm
@@ -59,9 +62,18 @@ public class KNearestNeighborClassifier
@Override
public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
if (mlt == null) {
- throw new IOException("You must first call Classifier#train first");
+ throw new IOException("You must first call Classifier#train");
+ }
+ Query q;
+ if (query != null) {
+ Query mltQuery = mlt.like(new StringReader(text), textFieldName);
+ BooleanQuery bq = new BooleanQuery();
+ bq.add(query, BooleanClause.Occur.MUST);
+ bq.add(mltQuery, BooleanClause.Occur.MUST);
+ q = bq;
+ } else {
+ q = mlt.like(new StringReader(text), textFieldName);
}
- Query q = mlt.like(new StringReader(text), textFieldName);
TopDocs topDocs = indexSearcher.search(q, k);
return selectClassFromNeighbors(topDocs);
}
@@ -71,13 +83,11 @@ public class KNearestNeighborClassifier
Map<BytesRef, Integer> classCounts = new HashMap<BytesRef, Integer>();
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
BytesRef cl = new BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue());
- if (cl != null) {
- Integer count = classCounts.get(cl);
- if (count != null) {
- classCounts.put(cl, count + 1);
- } else {
- classCounts.put(cl, 1);
- }
+ Integer count = classCounts.get(cl);
+ if (count != null) {
+ classCounts.put(cl, count + 1);
+ } else {
+ classCounts.put(cl, 1);
}
}
double max = 0;
@@ -98,11 +108,20 @@ public class KNearestNeighborClassifier
*/
@Override
public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
+ train(atomicReader, textFieldName, classFieldName, analyzer, null);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
this.textFieldName = textFieldName;
this.classFieldName = classFieldName;
mlt = new MoreLikeThis(atomicReader);
mlt.setAnalyzer(analyzer);
mlt.setFieldNames(new String[]{textFieldName});
indexSearcher = new IndexSearcher(atomicReader);
+ this.query = query;
}
-}
+}
\ No newline at end of file
Modified: lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java?rev=1533003&r1=1533002&r2=1533003&view=diff
==============================================================================
--- lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java (original)
+++ lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java Thu Oct 17 08:14:03 2013
@@ -27,6 +27,7 @@ import org.apache.lucene.index.TermsEnum
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.WildcardQuery;
@@ -50,6 +51,7 @@ public class SimpleNaiveBayesClassifier
private int docsWithClassSize;
private Analyzer analyzer;
private IndexSearcher indexSearcher;
+ private Query query;
/**
* Creates a new NaiveBayes classifier.
@@ -63,7 +65,7 @@ public class SimpleNaiveBayesClassifier
* {@inheritDoc}
*/
@Override
- public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer)
+ public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
throws IOException {
this.atomicReader = atomicReader;
this.indexSearcher = new IndexSearcher(this.atomicReader);
@@ -71,13 +73,29 @@ public class SimpleNaiveBayesClassifier
this.classFieldName = classFieldName;
this.analyzer = analyzer;
this.docsWithClassSize = countDocsWithClass();
+ this.query = query;
+ }
+
+ @Override
+ public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
+ train(atomicReader, textFieldName, classFieldName, analyzer, null);
}
private int countDocsWithClass() throws IOException {
int docCount = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount();
if (docCount == -1) { // in case codec doesn't support getDocCount
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
- indexSearcher.search(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))),
+ Query q;
+ if (query != null) {
+ BooleanQuery bq = new BooleanQuery();
+ WildcardQuery wq = new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING)));
+ bq.add(wq, BooleanClause.Occur.MUST);
+ bq.add(query, BooleanClause.Occur.MUST);
+ q = bq;
+ } else {
+ q = new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING)));
+ }
+ indexSearcher.search(q,
totalHitCountCollector);
docCount = totalHitCountCollector.getTotalHits();
}
@@ -106,7 +124,7 @@ public class SimpleNaiveBayesClassifier
@Override
public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException {
if (atomicReader == null) {
- throw new IOException("You must first call Classifier#train first");
+ throw new IOException("You must first call Classifier#train");
}
double max = 0d;
BytesRef foundClass = new BytesRef();
@@ -161,6 +179,9 @@ public class SimpleNaiveBayesClassifier
BooleanQuery booleanQuery = new BooleanQuery();
booleanQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.MUST));
booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
+ if (query != null) {
+ booleanQuery.add(query, BooleanClause.Occur.MUST);
+ }
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
indexSearcher.search(booleanQuery, totalHitCountCollector);
return totalHitCountCollector.getTotalHits();
@@ -173,4 +194,4 @@ public class SimpleNaiveBayesClassifier
private int docCount(BytesRef countedClass) throws IOException {
return atomicReader.docFreq(new Term(classFieldName, countedClass));
}
-}
+}
\ No newline at end of file
Modified: lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
URL: http://svn.apache.org/viewvc/lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java?rev=1533003&r1=1533002&r2=1533003&view=diff
==============================================================================
--- lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java (original)
+++ lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java Thu Oct 17 08:14:03 2013
@@ -24,12 +24,17 @@ import org.apache.lucene.document.TextFi
import org.apache.lucene.index.AtomicReader;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
+import org.apache.lucene.search.Query;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
+import org.apache.lucene.util._TestUtil;
import org.junit.After;
import org.junit.Before;
+import java.io.IOException;
+import java.util.Random;
+
/**
* Base class for testing {@link Classifier}s
*/
@@ -41,11 +46,13 @@ public abstract class ClassificationTest
public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
private RandomIndexWriter indexWriter;
- private String textFieldName;
private Directory dir;
+
+ String textFieldName;
String categoryFieldName;
String booleanFieldName;
+ @Override
@Before
public void setUp() throws Exception {
super.setUp();
@@ -56,6 +63,7 @@ public abstract class ClassificationTest
booleanFieldName = "bool";
}
+ @Override
@After
public void tearDown() throws Exception {
super.tearDown();
@@ -63,85 +71,147 @@ public abstract class ClassificationTest
dir.close();
}
+ protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
+ checkCorrectClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null);
+ }
- protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String classFieldName) throws Exception {
- AtomicReader compositeReaderWrapper = null;
+ protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
+ AtomicReader atomicReader = null;
try {
- populateIndex(analyzer);
- compositeReaderWrapper = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
- classifier.train(compositeReaderWrapper, textFieldName, classFieldName, analyzer);
+ populateSampleIndex(analyzer);
+ atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
+ classifier.train(atomicReader, textFieldName, classFieldName, analyzer, query);
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
assertNotNull(classificationResult.getAssignedClass());
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0);
} finally {
- if (compositeReaderWrapper != null)
- compositeReaderWrapper.close();
+ if (atomicReader != null)
+ atomicReader.close();
+ }
+ }
+
+ protected void checkPerformance(Classifier<T> classifier, Analyzer analyzer, String classFieldName) throws Exception {
+ AtomicReader atomicReader = null;
+ long trainStart = System.currentTimeMillis();
+ long trainEnd = 0l;
+ try {
+ populatePerformanceIndex(analyzer);
+ atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
+ classifier.train(atomicReader, textFieldName, classFieldName, analyzer);
+ trainEnd = System.currentTimeMillis();
+ long trainTime = trainEnd - trainStart;
+ assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
+ } finally {
+ if (atomicReader != null)
+ atomicReader.close();
+ }
+ }
+
+ private void populatePerformanceIndex(Analyzer analyzer) throws IOException {
+ indexWriter.deleteAll();
+ indexWriter.commit();
+
+ FieldType ft = new FieldType(TextField.TYPE_STORED);
+ ft.setStoreTermVectors(true);
+ ft.setStoreTermVectorOffsets(true);
+ ft.setStoreTermVectorPositions(true);
+ int docs = 1000;
+ Random random = random();
+ for (int i = 0; i < docs; i++) {
+ boolean b = random.nextBoolean();
+ Document doc = new Document();
+ doc.add(new Field(textFieldName, createRandomString(random), ft));
+ doc.add(new Field(categoryFieldName, b ? "technology" : "politics", ft));
+ doc.add(new Field(booleanFieldName, String.valueOf(b), ft));
+ indexWriter.addDocument(doc, analyzer);
+ }
+ indexWriter.commit();
+ }
+
+ private String createRandomString(Random random) {
+ StringBuilder builder = new StringBuilder();
+ for (int i = 0; i < 20; i++) {
+ builder.append(_TestUtil.randomSimpleString(random, 5));
+ builder.append(" ");
}
+ return builder.toString();
}
- private void populateIndex(Analyzer analyzer) throws Exception {
+ private void populateSampleIndex(Analyzer analyzer) throws Exception {
+
+ indexWriter.deleteAll();
+ indexWriter.commit();
FieldType ft = new FieldType(TextField.TYPE_STORED);
ft.setStoreTermVectors(true);
ft.setStoreTermVectorOffsets(true);
ft.setStoreTermVectorPositions(true);
+ String text;
+
Document doc = new Document();
- doc.add(new Field(textFieldName, "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " +
+ text = "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " +
"who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " +
- "the Unknown Soldier in Warsaw Tuesday.", ft));
+ "the Unknown Soldier in Warsaw Tuesday.";
+ doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "politics", ft));
- doc.add(new Field(booleanFieldName, "false", ft));
+ doc.add(new Field(booleanFieldName, "true", ft));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
- doc.add(new Field(textFieldName, "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" +
- " States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.", ft));
+ text = "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" +
+ " States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.";
+ doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "politics", ft));
- doc.add(new Field(booleanFieldName, "false", ft));
+ doc.add(new Field(booleanFieldName, "true", ft));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
- doc.add(new Field(textFieldName, "And there's a threshold question that he has to answer for the American people and " +
+ text = "And there's a threshold question that he has to answer for the American people and " +
"that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " +
- "know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"", ft));
+ "know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"";
+ doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "politics", ft));
- doc.add(new Field(booleanFieldName, "false", ft));
+ doc.add(new Field(booleanFieldName, "true", ft));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
- doc.add(new Field(textFieldName, "Still, when it comes to gun policy, many congressional Democrats have \"decided to " +
+ text = "Still, when it comes to gun policy, many congressional Democrats have \"decided to " +
"keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " +
- "Albany's School of Criminal Justice.", ft));
+ "Albany's School of Criminal Justice.";
+ doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "politics", ft));
- doc.add(new Field(booleanFieldName, "false", ft));
+ doc.add(new Field(booleanFieldName, "true", ft));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
- doc.add(new Field(textFieldName, "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " +
+ text = "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " +
"technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " +
- "world through the Internet.", ft));
+ "world through the Internet.";
+ doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "technology", ft));
- doc.add(new Field(booleanFieldName, "true", ft));
+ doc.add(new Field(booleanFieldName, "false", ft));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
- doc.add(new Field(textFieldName, "So, about all those experts and analysts who've spent the past year or so saying " +
- "Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.", ft));
+ text = "So, about all those experts and analysts who've spent the past year or so saying " +
+ "Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.";
+ doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "technology", ft));
- doc.add(new Field(booleanFieldName, "true", ft));
+ doc.add(new Field(booleanFieldName, "false", ft));
indexWriter.addDocument(doc, analyzer);
doc = new Document();
- doc.add(new Field(textFieldName, "More than 400 million people trust Google with their e-mail, and 50 million store files" +
+ text = "More than 400 million people trust Google with their e-mail, and 50 million store files" +
" in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " +
- "generally transfer or store huge volumes of personal data online.", ft));
+ "generally transfer or store huge volumes of personal data online.";
+ doc.add(new Field(textFieldName, text, ft));
doc.add(new Field(categoryFieldName, "technology", ft));
- doc.add(new Field(booleanFieldName, "true", ft));
+ doc.add(new Field(booleanFieldName, "false", ft));
indexWriter.addDocument(doc, analyzer);
indexWriter.commit();
}
-}
+}
\ No newline at end of file
Modified: lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
URL: http://svn.apache.org/viewvc/lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java?rev=1533003&r1=1533002&r2=1533003&view=diff
==============================================================================
--- lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java (original)
+++ lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java Thu Oct 17 08:14:03 2013
@@ -17,6 +17,8 @@
package org.apache.lucene.classification;
import org.apache.lucene.analysis.MockAnalyzer;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef;
import org.junit.Test;
@@ -27,7 +29,17 @@ public class KNearestNeighborClassifierT
@Test
public void testBasicUsage() throws Exception {
- checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), categoryFieldName);
+ checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
}
-}
+ @Test
+ public void testBasicUsageWithQuery() throws Exception {
+ checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it")));
+ }
+
+ @Test
+ public void testPerformance() throws Exception {
+ checkPerformance(new KNearestNeighborClassifier(100), new MockAnalyzer(random()), categoryFieldName);
+ }
+
+}
\ No newline at end of file
Modified: lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java
URL: http://svn.apache.org/viewvc/lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java?rev=1533003&r1=1533002&r2=1533003&view=diff
==============================================================================
--- lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java (original)
+++ lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java Thu Oct 17 08:14:03 2013
@@ -21,11 +21,11 @@ import org.apache.lucene.analysis.MockAn
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.core.KeywordTokenizer;
import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
-import org.apache.lucene.analysis.ngram.EdgeNGramTokenizer;
import org.apache.lucene.analysis.reverse.ReverseStringFilter;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
-import org.apache.lucene.util.Version;
import org.junit.Test;
import java.io.Reader;
@@ -39,13 +39,18 @@ public class SimpleNaiveBayesClassifierT
@Test
public void testBasicUsage() throws Exception {
- checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), categoryFieldName);
- checkCorrectClassification(new SimpleNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), categoryFieldName);
+ checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
+ checkCorrectClassification(new SimpleNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
+ }
+
+ @Test
+ public void testBasicUsageWithQuery() throws Exception {
+ checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it")));
}
@Test
public void testNGramUsage() throws Exception {
- checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), categoryFieldName);
+ checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName, categoryFieldName);
}
private class NGramAnalyzer extends Analyzer {
@@ -56,4 +61,9 @@ public class SimpleNaiveBayesClassifierT
}
}
-}
+ @Test
+ public void testPerformance() throws Exception {
+ checkPerformance(new SimpleNaiveBayesClassifier(), new MockAnalyzer(random()), categoryFieldName);
+ }
+
+}
\ No newline at end of file