You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by to...@apache.org on 2013/09/03 10:10:20 UTC

svn commit: r1519590 - in /lucene/dev/trunk/lucene/classification/src: java/org/apache/lucene/classification/ test/org/apache/lucene/classification/ test/org/apache/lucene/classification/utils/

Author: tommaso
Date: Tue Sep  3 08:10:19 2013
New Revision: 1519590

URL: http://svn.apache.org/r1519590
Log:
LUCENE-4818 - added boolean perceptron classifier

Added:
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java   (with props)
    lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java   (with props)
Modified:
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package.html
    lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
    lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
    lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java
    lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/DataSplitterTest.java

Added: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java?rev=1519590&view=auto
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java (added)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java Tue Sep  3 08:10:19 2013
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.classification;
+
+import java.io.IOException;
+import java.io.StringReader;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.lucene.index.AtomicReader;
+import org.apache.lucene.index.MultiFields;
+import org.apache.lucene.index.StorableField;
+import org.apache.lucene.index.StoredDocument;
+import org.apache.lucene.index.Terms;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.IntsRef;
+import org.apache.lucene.util.fst.Builder;
+import org.apache.lucene.util.fst.FST;
+import org.apache.lucene.util.fst.PositiveIntOutputs;
+import org.apache.lucene.util.fst.Util;
+
+/**
+ * A perceptron (see <code>http://en.wikipedia.org/wiki/Perceptron</code>) based
+ * <code>Boolean</code> {@link org.apache.lucene.classification.Classifier}. The
+ * weights are calculated using
+ * {@link org.apache.lucene.index.TermsEnum#totalTermFreq} both on a per field
+ * and a per document basis and then a corresponding
+ * {@link org.apache.lucene.util.fst.FST} is used for class assignment.
+ * 
+ * @lucene.experimental
+ */
+public class BooleanPerceptronClassifier implements Classifier<Boolean> {
+
+  private Double threshold;
+  private final Integer batchSize;
+  private Terms textTerms;
+  private Analyzer analyzer;
+  private String textFieldName;
+  private FST<Long> fst;
+
+  /**
+   * Create a {@link BooleanPerceptronClassifier}
+   * 
+   * @param threshold
+   *          the binary threshold for perceptron output evaluation, derived
+   *          from the index during training if <code>null</code> or 0
+   * @param batchSize
+   *          the FST is only rebuilt after a weight update on documents whose
+   *          zero-based training position is a multiple of this value
+   */
+  public BooleanPerceptronClassifier(Double threshold, Integer batchSize) {
+    this.threshold = threshold;
+    this.batchSize = batchSize;
+  }
+
+  /**
+   * Default constructor, no batch updates of FST, perceptron threshold is
+   * calculated via underlying index metrics during
+   * {@link #train(org.apache.lucene.index.AtomicReader, String, String, org.apache.lucene.analysis.Analyzer)
+   * training}
+   */
+  public BooleanPerceptronClassifier() {
+    batchSize = 1;
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public ClassificationResult<Boolean> assignClass(String text) throws IOException {
+    if (textTerms == null) {
+      throw new IOException("You must first call Classifier#train");
+    }
+    long output = 0L; // primitive accumulator, no boxing per token
+    // try-with-resources closes the stream even if analysis fails midway
+    try (TokenStream tokenStream = analyzer.tokenStream(textFieldName,
+        new StringReader(text))) {
+      CharTermAttribute charTermAttribute = tokenStream
+          .addAttribute(CharTermAttribute.class);
+      tokenStream.reset();
+      while (tokenStream.incrementToken()) {
+        // tokens unknown to the FST do not contribute to the output
+        Long weight = Util.get(fst, new BytesRef(charTermAttribute.toString()));
+        if (weight != null) {
+          output += weight;
+        }
+      }
+      tokenStream.end();
+    }
+
+    return new ClassificationResult<>(output >= threshold, (double) output);
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public void train(AtomicReader atomicReader, String textFieldName,
+      String classFieldName, Analyzer analyzer) throws IOException {
+    this.textTerms = MultiFields.getTerms(atomicReader, textFieldName);
+
+    if (textTerms == null) {
+      throw new IOException("term vectors need to be available for field "
+          + textFieldName);
+    }
+
+    this.analyzer = analyzer;
+    this.textFieldName = textFieldName;
+
+    if (threshold == null || threshold == 0d) {
+      // automatic assign a threshold
+      long sumDocFreq = atomicReader.getSumDocFreq(textFieldName);
+      if (sumDocFreq != -1) {
+        this.threshold = (double) sumDocFreq / 2d;
+      } else {
+        throw new IOException(
+            "threshold cannot be assigned since term vectors for field "
+                + textFieldName + " do not exist");
+      }
+    }
+
+    // TODO : remove this map as soon as we have a writable FST
+    SortedMap<String,Double> weights = new TreeMap<>();
+
+    // initial weights: the total frequency of every term of the text field
+    TermsEnum termsEnum = textTerms.iterator(null);
+    BytesRef textTerm;
+    while ((textTerm = termsEnum.next()) != null) {
+      weights.put(textTerm.utf8ToString(), (double) termsEnum.totalTermFreq());
+    }
+    updateFST(weights);
+
+    IndexSearcher indexSearcher = new IndexSearcher(atomicReader);
+
+    int batchCount = 0;
+
+    // do a *:* search and use stored field values
+    for (ScoreDoc scoreDoc : indexSearcher.search(new MatchAllDocsQuery(),
+        Integer.MAX_VALUE).scoreDocs) {
+      StoredDocument doc = indexSearcher.doc(scoreDoc.doc);
+
+      // assign class to the doc
+      ClassificationResult<Boolean> classificationResult = assignClass(doc
+          .getField(textFieldName).stringValue());
+      Boolean assignedClass = classificationResult.getAssignedClass();
+
+      // get the expected result
+      StorableField field = doc.getField(classFieldName);
+
+      Boolean correctClass = Boolean.valueOf(field.stringValue());
+      // positive if a true document was classified false, negative if vice versa
+      long modifier = correctClass.compareTo(assignedClass);
+      if (modifier != 0) {
+        updateWeights(atomicReader, scoreDoc.doc, weights, modifier,
+            batchCount % batchSize == 0);
+      }
+      batchCount++;
+    }
+    weights.clear(); // free memory while waiting for GC
+  }
+
+  /**
+   * Adds <code>modifier</code> times the term vector frequency of each term of
+   * the given document to that term's weight; rebuilds the FST on request.
+   * Fails with an IOException if the document stores no term vector.
+   */
+  private void updateWeights(AtomicReader atomicReader, int docId,
+      SortedMap<String,Double> weights, double modifier, boolean updateFST)
+      throws IOException {
+    // get the doc term vectors
+    Terms terms = atomicReader.getTermVector(docId, textFieldName);
+
+    if (terms == null) {
+      throw new IOException("term vectors must be stored for field "
+          + textFieldName);
+    }
+
+    TermsEnum termsEnum = terms.iterator(null);
+    BytesRef term;
+    while ((term = termsEnum.next()) != null) {
+      long termFreqLocal = termsEnum.totalTermFreq();
+      String termString = term.utf8ToString();
+      // read the current weight from the map, not the FST: the FST is only
+      // rebuilt on batch boundaries and would drop updates made since then
+      Double previousValue = weights.get(termString);
+      double base = previousValue == null ? 0d : previousValue;
+      weights.put(termString, base + modifier * termFreqLocal);
+    }
+    if (updateFST) {
+      updateFST(weights);
+    }
+  }
+
+  /** Rebuilds the FST from the given weights, truncating each to a long. */
+  private void updateFST(SortedMap<String,Double> weights) throws IOException {
+    PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton();
+    Builder<Long> fstBuilder = new Builder<>(FST.INPUT_TYPE.BYTE1, outputs);
+    BytesRef scratchBytes = new BytesRef();
+    IntsRef scratchInts = new IntsRef();
+    for (Map.Entry<String,Double> entry : weights.entrySet()) {
+      scratchBytes.copyChars(entry.getKey());
+      fstBuilder.add(Util.toIntsRef(scratchBytes, scratchInts), entry
+          .getValue().longValue());
+    }
+    fst = fstBuilder.finish();
+  }
+
+}
\ No newline at end of file

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java?rev=1519590&r1=1519589&r2=1519590&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java Tue Sep  3 08:10:19 2013
@@ -59,7 +59,7 @@ public class KNearestNeighborClassifier 
   @Override
   public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
     if (mlt == null) {
-      throw new IOException("You must first call Classifier#train first");
+      throw new IOException("You must first call Classifier#train");
     }
     Query q = mlt.like(new StringReader(text), textFieldName);
     TopDocs topDocs = indexSearcher.search(q, k);
@@ -71,13 +71,11 @@ public class KNearestNeighborClassifier 
     Map<BytesRef, Integer> classCounts = new HashMap<BytesRef, Integer>();
     for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
       BytesRef cl = new BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue());
-      if (cl != null) {
-        Integer count = classCounts.get(cl);
-        if (count != null) {
-          classCounts.put(cl, count + 1);
-        } else {
-          classCounts.put(cl, 1);
-        }
+      Integer count = classCounts.get(cl);
+      if (count != null) {
+        classCounts.put(cl, count + 1);
+      } else {
+        classCounts.put(cl, 1);
       }
     }
     double max = 0;

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java?rev=1519590&r1=1519589&r2=1519590&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java Tue Sep  3 08:10:19 2013
@@ -102,7 +102,7 @@ public class SimpleNaiveBayesClassifier 
   @Override
   public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException {
     if (atomicReader == null) {
-      throw new IOException("You must first call Classifier#train first");
+      throw new IOException("You must first call Classifier#train");
     }
     double max = 0d;
     BytesRef foundClass = new BytesRef();

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package.html
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package.html?rev=1519590&r1=1519589&r2=1519590&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package.html (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package.html Tue Sep  3 08:10:19 2013
@@ -17,7 +17,7 @@
 <html>
 <body>
 Uses already seen data (the indexed documents) to classify new documents.
-Currently only contains a (simplistic) Lucene based Naive Bayes classifier 
-and a k-Nearest Neighbor classifier
+Currently only contains a (simplistic) Lucene based Naive Bayes classifier,
+a k-Nearest Neighbor classifier and a Perceptron based classifier
 </body>
 </html>

Added: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java?rev=1519590&view=auto
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java (added)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java Tue Sep  3 08:10:19 2013
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.classification;
+
+import org.apache.lucene.analysis.MockAnalyzer;
+import org.junit.Test;
+
+/**
+ * Testcase for {@link org.apache.lucene.classification.BooleanPerceptronClassifier}
+ */
+public class BooleanPerceptronClassifierTest extends ClassificationTestBase<Boolean> {
+  // Default constructor: TECHNOLOGY_INPUT is expected to be classified as false.
+  @Test
+  public void testBasicUsage() throws Exception {
+    checkCorrectClassification(new BooleanPerceptronClassifier(), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName);
+  }
+  // Explicit threshold of 100 and batch size of 1: same expectation as the basic usage test.
+  @Test
+  public void testExplicitThreshold() throws Exception {
+    checkCorrectClassification(new BooleanPerceptronClassifier(100d, 1), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName);
+  }
+  // Training on the performance index has to stay within the time limit asserted by checkPerformance.
+  @Test
+  public void testPerformance() throws Exception {
+    checkPerformance(new BooleanPerceptronClassifier(), new MockAnalyzer(random()), booleanFieldName);
+  }
+
+}

Modified: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java?rev=1519590&r1=1519589&r2=1519590&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java (original)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java Tue Sep  3 08:10:19 2013
@@ -27,9 +27,13 @@ import org.apache.lucene.index.SlowCompo
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.LuceneTestCase;
+import org.apache.lucene.util._TestUtil;
 import org.junit.After;
 import org.junit.Before;
 
+import java.io.IOException;
+import java.util.Random;
+
 /**
  * Base class for testing {@link Classifier}s
  */
@@ -41,8 +45,9 @@ public abstract class ClassificationTest
   public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
 
   private RandomIndexWriter indexWriter;
-  private String textFieldName;
   private Directory dir;
+
+  String textFieldName;
   String categoryFieldName;
   String booleanFieldName;
 
@@ -66,82 +71,141 @@ public abstract class ClassificationTest
   }
 
 
-  protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String classFieldName) throws Exception {
-    AtomicReader compositeReaderWrapper = null;
+  protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
+    AtomicReader atomicReader = null;
     try {
-      populateIndex(analyzer);
-      compositeReaderWrapper = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
-      classifier.train(compositeReaderWrapper, textFieldName, classFieldName, analyzer);
+      populateSampleIndex(analyzer);
+      atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
+      classifier.train(atomicReader, textFieldName, classFieldName, analyzer);
       ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
       assertNotNull(classificationResult.getAssignedClass());
       assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
       assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0);
     } finally {
-      if (compositeReaderWrapper != null)
-        compositeReaderWrapper.close();
+      if (atomicReader != null)
+        atomicReader.close();
+    }
+  }
+
+  /** Trains the classifier on the performance index and fails unless that took less than 120000 ms. */
+  protected void checkPerformance(Classifier<T> classifier, Analyzer analyzer, String classFieldName) throws Exception {
+    AtomicReader reader = null;
+    final long startMillis = System.currentTimeMillis();
+    try {
+      // the clock is already running here, so index population is part of the measured time
+      populatePerformanceIndex(analyzer);
+      reader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
+      classifier.train(reader, textFieldName, classFieldName, analyzer);
+      long elapsedMillis = System.currentTimeMillis() - startMillis;
+      assertTrue("training took more than 2 mins : " + elapsedMillis / 1000 + "s", elapsedMillis < 120000);
+    } finally {
+      if (reader != null)
+        reader.close();
     }
   }
 
-  private void populateIndex(Analyzer analyzer) throws Exception {
+  private void populatePerformanceIndex(Analyzer analyzer) throws IOException {
+    indexWriter.deleteAll(); // start from an empty index
+    indexWriter.commit();
 
     FieldType ft = new FieldType(TextField.TYPE_STORED);
     ft.setStoreTermVectors(true);
     ft.setStoreTermVectorOffsets(true);
     ft.setStoreTermVectorPositions(true);
+    int docs = 1000; // number of random documents to index
+    Random random = random();
+    for (int i = 0; i < docs; i++) {
+      boolean b = random.nextBoolean(); // decides both the category and the boolean field value
+      Document doc = new Document();
+      doc.add(new Field(textFieldName, createRandomString(random), ft));
+      doc.add(new Field(categoryFieldName, b ? "technology" : "politics", ft));
+      doc.add(new Field(booleanFieldName, String.valueOf(b), ft));
+      indexWriter.addDocument(doc, analyzer);
+    }
+    indexWriter.commit();
+  }
+
+  /** Concatenates 20 random simple strings, each followed by a single space. */
+  private String createRandomString(Random random) {
+    StringBuilder text = new StringBuilder();
+    for (int remaining = 20; remaining > 0; remaining--) {
+      text.append(_TestUtil.randomSimpleString(random, 5)).append(" ");
+    }
+    return text.toString();
+  }
+
+  private void populateSampleIndex(Analyzer analyzer) throws Exception {
+
+    indexWriter.deleteAll();
+    indexWriter.commit();
+
+    FieldType ft = new FieldType(TextField.TYPE_STORED);
+    ft.setStoreTermVectors(true);
+    ft.setStoreTermVectorOffsets(true);
+    ft.setStoreTermVectorPositions(true);
+
+    String text;
 
     Document doc = new Document();
-    doc.add(new Field(textFieldName, "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " +
+    text = "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " +
         "who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " +
-        "the Unknown Soldier in Warsaw Tuesday.", ft));
+        "the Unknown Soldier in Warsaw Tuesday.";
+    doc.add(new Field(textFieldName, text, ft));
     doc.add(new Field(categoryFieldName, "politics", ft));
-    doc.add(new Field(booleanFieldName, "false", ft));
+    doc.add(new Field(booleanFieldName, "true", ft));
 
     indexWriter.addDocument(doc, analyzer);
 
     doc = new Document();
-    doc.add(new Field(textFieldName, "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" +
-        " States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.", ft));
+    text = "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" +
+        " States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.";
+    doc.add(new Field(textFieldName, text, ft));
     doc.add(new Field(categoryFieldName, "politics", ft));
-    doc.add(new Field(booleanFieldName, "false", ft));
+    doc.add(new Field(booleanFieldName, "true", ft));
     indexWriter.addDocument(doc, analyzer);
 
     doc = new Document();
-    doc.add(new Field(textFieldName, "And there's a threshold question that he has to answer for the American people and " +
+    text = "And there's a threshold question that he has to answer for the American people and " +
         "that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " +
-        "know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"", ft));
+        "know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"";
+    doc.add(new Field(textFieldName, text, ft));
     doc.add(new Field(categoryFieldName, "politics", ft));
-    doc.add(new Field(booleanFieldName, "false", ft));
+    doc.add(new Field(booleanFieldName, "true", ft));
     indexWriter.addDocument(doc, analyzer);
 
     doc = new Document();
-    doc.add(new Field(textFieldName, "Still, when it comes to gun policy, many congressional Democrats have \"decided to " +
+    text = "Still, when it comes to gun policy, many congressional Democrats have \"decided to " +
         "keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " +
-        "Albany's School of Criminal Justice.", ft));
+        "Albany's School of Criminal Justice.";
+    doc.add(new Field(textFieldName, text, ft));
     doc.add(new Field(categoryFieldName, "politics", ft));
-    doc.add(new Field(booleanFieldName, "false", ft));
+    doc.add(new Field(booleanFieldName, "true", ft));
     indexWriter.addDocument(doc, analyzer);
 
     doc = new Document();
-    doc.add(new Field(textFieldName, "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " +
+    text = "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " +
         "technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " +
-        "world through the Internet.", ft));
+        "world through the Internet.";
+    doc.add(new Field(textFieldName, text, ft));
     doc.add(new Field(categoryFieldName, "technology", ft));
-    doc.add(new Field(booleanFieldName, "true", ft));
+    doc.add(new Field(booleanFieldName, "false", ft));
     indexWriter.addDocument(doc, analyzer);
 
     doc = new Document();
-    doc.add(new Field(textFieldName, "So, about all those experts and analysts who've spent the past year or so saying " +
-        "Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.", ft));
+    text = "So, about all those experts and analysts who've spent the past year or so saying " +
+        "Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.";
+    doc.add(new Field(textFieldName, text, ft));
     doc.add(new Field(categoryFieldName, "technology", ft));
-    doc.add(new Field(booleanFieldName, "true", ft));
+    doc.add(new Field(booleanFieldName, "false", ft));
     indexWriter.addDocument(doc, analyzer);
 
     doc = new Document();
-    doc.add(new Field(textFieldName, "More than 400 million people trust Google with their e-mail, and 50 million store files" +
+    text = "More than 400 million people trust Google with their e-mail, and 50 million store files" +
         " in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " +
-        "generally transfer or store huge volumes of personal data online.", ft));
+        "generally transfer or store huge volumes of personal data online.";
+    doc.add(new Field(textFieldName, text, ft));
     doc.add(new Field(categoryFieldName, "technology", ft));
-    doc.add(new Field(booleanFieldName, "true", ft));
+    doc.add(new Field(booleanFieldName, "false", ft));
     indexWriter.addDocument(doc, analyzer);
 
     indexWriter.commit();

Modified: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java?rev=1519590&r1=1519589&r2=1519590&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java (original)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java Tue Sep  3 08:10:19 2013
@@ -27,7 +27,12 @@ public class KNearestNeighborClassifierT
 
   @Test
   public void testBasicUsage() throws Exception {
-    checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), categoryFieldName);
+    checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
+  }
+
+  @Test
+  public void testPerformance() throws Exception {
+    checkPerformance(new KNearestNeighborClassifier(100), new MockAnalyzer(random()), categoryFieldName);
   }
 
 }

Modified: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java?rev=1519590&r1=1519589&r2=1519590&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java (original)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java Tue Sep  3 08:10:19 2013
@@ -21,11 +21,9 @@ import org.apache.lucene.analysis.MockAn
 import org.apache.lucene.analysis.Tokenizer;
 import org.apache.lucene.analysis.core.KeywordTokenizer;
 import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
-import org.apache.lucene.analysis.ngram.EdgeNGramTokenizer;
 import org.apache.lucene.analysis.reverse.ReverseStringFilter;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.LuceneTestCase;
-import org.apache.lucene.util.Version;
 import org.junit.Test;
 
 import java.io.Reader;
@@ -39,13 +37,13 @@ public class SimpleNaiveBayesClassifierT
 
   @Test
   public void testBasicUsage() throws Exception {
-    checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), categoryFieldName);
-    checkCorrectClassification(new SimpleNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), categoryFieldName);
+    checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
+    checkCorrectClassification(new SimpleNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
   }
 
   @Test
   public void testNGramUsage() throws Exception {
-    checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), categoryFieldName);
+    checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName, categoryFieldName);
   }
 
   private class NGramAnalyzer extends Analyzer {
@@ -56,4 +54,9 @@ public class SimpleNaiveBayesClassifierT
     }
   }
 
+  @Test
+  public void testPerformance() throws Exception {
+    checkPerformance(new SimpleNaiveBayesClassifier(), new MockAnalyzer(random()), categoryFieldName);
+  }
+
 }

Modified: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/DataSplitterTest.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/DataSplitterTest.java?rev=1519590&r1=1519589&r2=1519590&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/DataSplitterTest.java (original)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/DataSplitterTest.java Tue Sep  3 08:10:19 2013
@@ -131,9 +131,15 @@ public class DataSplitterTest extends Lu
       closeQuietly(testReader);
       closeQuietly(cvReader);
     } finally {
-      trainingIndex.close();
-      testIndex.close();
-      crossValidationIndex.close();
+      if (trainingIndex != null) {
+        trainingIndex.close();
+      }
+      if (testIndex != null) {
+        testIndex.close();
+      }
+      if (crossValidationIndex != null) {
+        crossValidationIndex.close();
+      }
     }
   }