You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by to...@apache.org on 2013/10/17 10:14:03 UTC

svn commit: r1533003 - in /lucene/dev/branches/branch_4x/lucene/classification/src: java/org/apache/lucene/classification/ test/org/apache/lucene/classification/

Author: tommaso
Date: Thu Oct 17 08:14:03 2013
New Revision: 1533003

URL: http://svn.apache.org/r1533003
Log:
LUCENE-5290 - backport LUCENE-5284 and performance tests from trunk

Modified:
    lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
    lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
    lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
    lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
    lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
    lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java

Modified: lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java?rev=1533003&r1=1533002&r2=1533003&view=diff
==============================================================================
--- lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java (original)
+++ lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java Thu Oct 17 08:14:03 2013
@@ -18,6 +18,7 @@ package org.apache.lucene.classification
 
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.index.AtomicReader;
+import org.apache.lucene.search.Query;
 
 import java.io.IOException;
 
@@ -47,4 +48,16 @@ public interface Classifier<T> {
   public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer)
       throws IOException;
 
+  /**
+   * Train the classifier using the underlying Lucene index
+   * @param atomicReader the reader to use to access the Lucene index
+   * @param textFieldName the name of the field used to compare documents
+   * @param classFieldName the name of the field containing the class assigned to documents
+   * @param analyzer the analyzer used to tokenize / filter the unseen text
+   * @param query the query used to filter which documents are used for training
+   * @throws IOException If there is a low-level I/O error.
+   */
+  public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
+      throws IOException;
+
 }

Modified: lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java?rev=1533003&r1=1533002&r2=1533003&view=diff
==============================================================================
--- lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java (original)
+++ lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java Thu Oct 17 08:14:03 2013
@@ -19,6 +19,8 @@ package org.apache.lucene.classification
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.index.AtomicReader;
 import org.apache.lucene.queries.mlt.MoreLikeThis;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreDoc;
@@ -43,6 +45,7 @@ public class KNearestNeighborClassifier 
   private String classFieldName;
   private IndexSearcher indexSearcher;
   private int k;
+  private Query query;
 
   /**
    * Create a {@link Classifier} using kNN algorithm
@@ -59,9 +62,18 @@ public class KNearestNeighborClassifier 
   @Override
   public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
     if (mlt == null) {
-      throw new IOException("You must first call Classifier#train first");
+      throw new IOException("You must first call Classifier#train");
+    }
+    Query q;
+    if (query != null) {
+      Query mltQuery = mlt.like(new StringReader(text), textFieldName);
+      BooleanQuery bq = new BooleanQuery();
+      bq.add(query, BooleanClause.Occur.MUST);
+      bq.add(mltQuery, BooleanClause.Occur.MUST);
+      q = bq;
+    } else {
+      q = mlt.like(new StringReader(text), textFieldName);
     }
-    Query q = mlt.like(new StringReader(text), textFieldName);
     TopDocs topDocs = indexSearcher.search(q, k);
     return selectClassFromNeighbors(topDocs);
   }
@@ -71,13 +83,11 @@ public class KNearestNeighborClassifier 
     Map<BytesRef, Integer> classCounts = new HashMap<BytesRef, Integer>();
     for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
       BytesRef cl = new BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue());
-      if (cl != null) {
-        Integer count = classCounts.get(cl);
-        if (count != null) {
-          classCounts.put(cl, count + 1);
-        } else {
-          classCounts.put(cl, 1);
-        }
+      Integer count = classCounts.get(cl);
+      if (count != null) {
+        classCounts.put(cl, count + 1);
+      } else {
+        classCounts.put(cl, 1);
       }
     }
     double max = 0;
@@ -98,11 +108,20 @@ public class KNearestNeighborClassifier 
    */
   @Override
   public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
+    train(atomicReader, textFieldName, classFieldName, analyzer, null);
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
     this.textFieldName = textFieldName;
     this.classFieldName = classFieldName;
     mlt = new MoreLikeThis(atomicReader);
     mlt.setAnalyzer(analyzer);
     mlt.setFieldNames(new String[]{textFieldName});
     indexSearcher = new IndexSearcher(atomicReader);
+    this.query = query;
   }
-}
+}
\ No newline at end of file

Modified: lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java?rev=1533003&r1=1533002&r2=1533003&view=diff
==============================================================================
--- lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java (original)
+++ lucene/dev/branches/branch_4x/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java Thu Oct 17 08:14:03 2013
@@ -27,6 +27,7 @@ import org.apache.lucene.index.TermsEnum
 import org.apache.lucene.search.BooleanClause;
 import org.apache.lucene.search.BooleanQuery;
 import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
 import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.search.TotalHitCountCollector;
 import org.apache.lucene.search.WildcardQuery;
@@ -50,6 +51,7 @@ public class SimpleNaiveBayesClassifier 
   private int docsWithClassSize;
   private Analyzer analyzer;
   private IndexSearcher indexSearcher;
+  private Query query;
 
   /**
    * Creates a new NaiveBayes classifier.
@@ -63,7 +65,7 @@ public class SimpleNaiveBayesClassifier 
    * {@inheritDoc}
    */
   @Override
-  public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer)
+  public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
       throws IOException {
     this.atomicReader = atomicReader;
     this.indexSearcher = new IndexSearcher(this.atomicReader);
@@ -71,13 +73,29 @@ public class SimpleNaiveBayesClassifier 
     this.classFieldName = classFieldName;
     this.analyzer = analyzer;
     this.docsWithClassSize = countDocsWithClass();
+    this.query = query;
+  }
+
+  @Override
+  public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
+    train(atomicReader, textFieldName, classFieldName, analyzer, null);
   }
 
   private int countDocsWithClass() throws IOException {
     int docCount = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount();
     if (docCount == -1) { // in case codec doesn't support getDocCount
       TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
-      indexSearcher.search(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))),
+      Query q;
+      if (query != null) {
+        BooleanQuery bq = new BooleanQuery();
+        WildcardQuery wq = new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING)));
+        bq.add(wq, BooleanClause.Occur.MUST);
+        bq.add(query, BooleanClause.Occur.MUST);
+        q = bq;
+      } else {
+        q = new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING)));
+      }
+      indexSearcher.search(q,
           totalHitCountCollector);
       docCount = totalHitCountCollector.getTotalHits();
     }
@@ -106,7 +124,7 @@ public class SimpleNaiveBayesClassifier 
   @Override
   public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException {
     if (atomicReader == null) {
-      throw new IOException("You must first call Classifier#train first");
+      throw new IOException("You must first call Classifier#train");
     }
     double max = 0d;
     BytesRef foundClass = new BytesRef();
@@ -161,6 +179,9 @@ public class SimpleNaiveBayesClassifier 
     BooleanQuery booleanQuery = new BooleanQuery();
     booleanQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.MUST));
     booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
+    if (query != null) {
+      booleanQuery.add(query, BooleanClause.Occur.MUST);
+    }
     TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
     indexSearcher.search(booleanQuery, totalHitCountCollector);
     return totalHitCountCollector.getTotalHits();
@@ -173,4 +194,4 @@ public class SimpleNaiveBayesClassifier 
   private int docCount(BytesRef countedClass) throws IOException {
     return atomicReader.docFreq(new Term(classFieldName, countedClass));
   }
-}
+}
\ No newline at end of file

Modified: lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
URL: http://svn.apache.org/viewvc/lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java?rev=1533003&r1=1533002&r2=1533003&view=diff
==============================================================================
--- lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java (original)
+++ lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java Thu Oct 17 08:14:03 2013
@@ -24,12 +24,17 @@ import org.apache.lucene.document.TextFi
 import org.apache.lucene.index.AtomicReader;
 import org.apache.lucene.index.RandomIndexWriter;
 import org.apache.lucene.index.SlowCompositeReaderWrapper;
+import org.apache.lucene.search.Query;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.LuceneTestCase;
+import org.apache.lucene.util._TestUtil;
 import org.junit.After;
 import org.junit.Before;
 
+import java.io.IOException;
+import java.util.Random;
+
 /**
  * Base class for testing {@link Classifier}s
  */
@@ -41,11 +46,13 @@ public abstract class ClassificationTest
   public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
 
   private RandomIndexWriter indexWriter;
-  private String textFieldName;
   private Directory dir;
+
+  String textFieldName;
   String categoryFieldName;
   String booleanFieldName;
 
+  @Override
   @Before
   public void setUp() throws Exception {
     super.setUp();
@@ -56,6 +63,7 @@ public abstract class ClassificationTest
     booleanFieldName = "bool";
   }
 
+  @Override
   @After
   public void tearDown() throws Exception {
     super.tearDown();
@@ -63,85 +71,147 @@ public abstract class ClassificationTest
     dir.close();
   }
 
+  protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
+    checkCorrectClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null);
+  }
 
-  protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String classFieldName) throws Exception {
-    AtomicReader compositeReaderWrapper = null;
+  protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
+    AtomicReader atomicReader = null;
     try {
-      populateIndex(analyzer);
-      compositeReaderWrapper = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
-      classifier.train(compositeReaderWrapper, textFieldName, classFieldName, analyzer);
+      populateSampleIndex(analyzer);
+      atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
+      classifier.train(atomicReader, textFieldName, classFieldName, analyzer, query);
       ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
       assertNotNull(classificationResult.getAssignedClass());
       assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
       assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0);
     } finally {
-      if (compositeReaderWrapper != null)
-        compositeReaderWrapper.close();
+      if (atomicReader != null)
+        atomicReader.close();
+    }
+  }
+
+  protected void checkPerformance(Classifier<T> classifier, Analyzer analyzer, String classFieldName) throws Exception {
+    AtomicReader atomicReader = null;
+    long trainStart = System.currentTimeMillis();
+    long trainEnd = 0l;
+    try {
+      populatePerformanceIndex(analyzer);
+      atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
+      classifier.train(atomicReader, textFieldName, classFieldName, analyzer);
+      trainEnd = System.currentTimeMillis();
+      long trainTime = trainEnd - trainStart;
+      assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
+    } finally {
+      if (atomicReader != null)
+        atomicReader.close();
+    }
+  }
+
+  private void populatePerformanceIndex(Analyzer analyzer) throws IOException {
+    indexWriter.deleteAll();
+    indexWriter.commit();
+
+    FieldType ft = new FieldType(TextField.TYPE_STORED);
+    ft.setStoreTermVectors(true);
+    ft.setStoreTermVectorOffsets(true);
+    ft.setStoreTermVectorPositions(true);
+    int docs = 1000;
+    Random random = random();
+    for (int i = 0; i < docs; i++) {
+      boolean b = random.nextBoolean();
+      Document doc = new Document();
+      doc.add(new Field(textFieldName, createRandomString(random), ft));
+      doc.add(new Field(categoryFieldName, b ? "technology" : "politics", ft));
+      doc.add(new Field(booleanFieldName, String.valueOf(b), ft));
+      indexWriter.addDocument(doc, analyzer);
+    }
+    indexWriter.commit();
+  }
+
+  private String createRandomString(Random random) {
+    StringBuilder builder = new StringBuilder();
+    for (int i = 0; i < 20; i++) {
+      builder.append(_TestUtil.randomSimpleString(random, 5));
+      builder.append(" ");
     }
+    return builder.toString();
   }
 
-  private void populateIndex(Analyzer analyzer) throws Exception {
+  private void populateSampleIndex(Analyzer analyzer) throws Exception {
+
+    indexWriter.deleteAll();
+    indexWriter.commit();
 
     FieldType ft = new FieldType(TextField.TYPE_STORED);
     ft.setStoreTermVectors(true);
     ft.setStoreTermVectorOffsets(true);
     ft.setStoreTermVectorPositions(true);
 
+    String text;
+
     Document doc = new Document();
-    doc.add(new Field(textFieldName, "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " +
+    text = "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " +
         "who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " +
-        "the Unknown Soldier in Warsaw Tuesday.", ft));
+        "the Unknown Soldier in Warsaw Tuesday.";
+    doc.add(new Field(textFieldName, text, ft));
     doc.add(new Field(categoryFieldName, "politics", ft));
-    doc.add(new Field(booleanFieldName, "false", ft));
+    doc.add(new Field(booleanFieldName, "true", ft));
 
     indexWriter.addDocument(doc, analyzer);
 
     doc = new Document();
-    doc.add(new Field(textFieldName, "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" +
-        " States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.", ft));
+    text = "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" +
+        " States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.";
+    doc.add(new Field(textFieldName, text, ft));
     doc.add(new Field(categoryFieldName, "politics", ft));
-    doc.add(new Field(booleanFieldName, "false", ft));
+    doc.add(new Field(booleanFieldName, "true", ft));
     indexWriter.addDocument(doc, analyzer);
 
     doc = new Document();
-    doc.add(new Field(textFieldName, "And there's a threshold question that he has to answer for the American people and " +
+    text = "And there's a threshold question that he has to answer for the American people and " +
         "that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " +
-        "know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"", ft));
+        "know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"";
+    doc.add(new Field(textFieldName, text, ft));
     doc.add(new Field(categoryFieldName, "politics", ft));
-    doc.add(new Field(booleanFieldName, "false", ft));
+    doc.add(new Field(booleanFieldName, "true", ft));
     indexWriter.addDocument(doc, analyzer);
 
     doc = new Document();
-    doc.add(new Field(textFieldName, "Still, when it comes to gun policy, many congressional Democrats have \"decided to " +
+    text = "Still, when it comes to gun policy, many congressional Democrats have \"decided to " +
         "keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " +
-        "Albany's School of Criminal Justice.", ft));
+        "Albany's School of Criminal Justice.";
+    doc.add(new Field(textFieldName, text, ft));
     doc.add(new Field(categoryFieldName, "politics", ft));
-    doc.add(new Field(booleanFieldName, "false", ft));
+    doc.add(new Field(booleanFieldName, "true", ft));
     indexWriter.addDocument(doc, analyzer);
 
     doc = new Document();
-    doc.add(new Field(textFieldName, "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " +
+    text = "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " +
         "technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " +
-        "world through the Internet.", ft));
+        "world through the Internet.";
+    doc.add(new Field(textFieldName, text, ft));
     doc.add(new Field(categoryFieldName, "technology", ft));
-    doc.add(new Field(booleanFieldName, "true", ft));
+    doc.add(new Field(booleanFieldName, "false", ft));
     indexWriter.addDocument(doc, analyzer);
 
     doc = new Document();
-    doc.add(new Field(textFieldName, "So, about all those experts and analysts who've spent the past year or so saying " +
-        "Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.", ft));
+    text = "So, about all those experts and analysts who've spent the past year or so saying " +
+        "Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.";
+    doc.add(new Field(textFieldName, text, ft));
     doc.add(new Field(categoryFieldName, "technology", ft));
-    doc.add(new Field(booleanFieldName, "true", ft));
+    doc.add(new Field(booleanFieldName, "false", ft));
     indexWriter.addDocument(doc, analyzer);
 
     doc = new Document();
-    doc.add(new Field(textFieldName, "More than 400 million people trust Google with their e-mail, and 50 million store files" +
+    text = "More than 400 million people trust Google with their e-mail, and 50 million store files" +
         " in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " +
-        "generally transfer or store huge volumes of personal data online.", ft));
+        "generally transfer or store huge volumes of personal data online.";
+    doc.add(new Field(textFieldName, text, ft));
     doc.add(new Field(categoryFieldName, "technology", ft));
-    doc.add(new Field(booleanFieldName, "true", ft));
+    doc.add(new Field(booleanFieldName, "false", ft));
     indexWriter.addDocument(doc, analyzer);
 
     indexWriter.commit();
   }
-}
+}
\ No newline at end of file

Modified: lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
URL: http://svn.apache.org/viewvc/lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java?rev=1533003&r1=1533002&r2=1533003&view=diff
==============================================================================
--- lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java (original)
+++ lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java Thu Oct 17 08:14:03 2013
@@ -17,6 +17,8 @@
 package org.apache.lucene.classification;
 
 import org.apache.lucene.analysis.MockAnalyzer;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.util.BytesRef;
 import org.junit.Test;
 
@@ -27,7 +29,17 @@ public class KNearestNeighborClassifierT
 
   @Test
   public void testBasicUsage() throws Exception {
-    checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), categoryFieldName);
+    checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
   }
 
-}
+  @Test
+  public void testBasicUsageWithQuery() throws Exception {
+    checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it")));
+  }
+
+  @Test
+  public void testPerformance() throws Exception {
+    checkPerformance(new KNearestNeighborClassifier(100), new MockAnalyzer(random()), categoryFieldName);
+  }
+
+}
\ No newline at end of file

Modified: lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java
URL: http://svn.apache.org/viewvc/lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java?rev=1533003&r1=1533002&r2=1533003&view=diff
==============================================================================
--- lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java (original)
+++ lucene/dev/branches/branch_4x/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java Thu Oct 17 08:14:03 2013
@@ -21,11 +21,11 @@ import org.apache.lucene.analysis.MockAn
 import org.apache.lucene.analysis.Tokenizer;
 import org.apache.lucene.analysis.core.KeywordTokenizer;
 import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
-import org.apache.lucene.analysis.ngram.EdgeNGramTokenizer;
 import org.apache.lucene.analysis.reverse.ReverseStringFilter;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.LuceneTestCase;
-import org.apache.lucene.util.Version;
 import org.junit.Test;
 
 import java.io.Reader;
@@ -39,13 +39,18 @@ public class SimpleNaiveBayesClassifierT
 
   @Test
   public void testBasicUsage() throws Exception {
-    checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), categoryFieldName);
-    checkCorrectClassification(new SimpleNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), categoryFieldName);
+    checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
+    checkCorrectClassification(new SimpleNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
+  }
+
+  @Test
+  public void testBasicUsageWithQuery() throws Exception {
+    checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it")));
   }
 
   @Test
   public void testNGramUsage() throws Exception {
-    checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), categoryFieldName);
+    checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName, categoryFieldName);
   }
 
   private class NGramAnalyzer extends Analyzer {
@@ -56,4 +61,9 @@ public class SimpleNaiveBayesClassifierT
     }
   }
 
-}
+  @Test
+  public void testPerformance() throws Exception {
+    checkPerformance(new SimpleNaiveBayesClassifier(), new MockAnalyzer(random()), categoryFieldName);
+  }
+
+}
\ No newline at end of file