You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@lucene.apache.org by Tommaso Teofili <to...@gmail.com> on 2015/05/03 08:55:50 UTC

Re: svn commit: r1676997 - in /lucene/dev/trunk/lucene/classification/src: java/org/apache/lucene/classification/ java/org/apache/lucene/classification/utils/ test/org/apache/lucene/classification/

Sorry for the inconvenience: the API refactoring changed the constructors, but I
forgot to update the related javadocs. It should be fixed now.

Tommaso

2015-04-30 18:28 GMT+02:00 Ryan Ernst <ry...@iernst.net>:

> Did you mean to remove javadocs on BooleanPerceptronClassifier? This
> breaks precommit...
>
> On Thu, Apr 30, 2015 at 7:12 AM, <to...@apache.org> wrote:
>
>> Author: tommaso
>> Date: Thu Apr 30 14:12:03 2015
>> New Revision: 1676997
>>
>> URL: http://svn.apache.org/r1676997
>> Log:
>> LUCENE-6045 - refactor Classifier API to work better with multithreading
>>
>> Modified:
>>
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
>>
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
>>
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java
>>
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
>>
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
>>
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
>>
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java
>>
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java
>>
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java
>>
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java
>>
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
>>
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
>>
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
>> Thu Apr 30 14:12:03 2015
>> @@ -58,76 +58,14 @@ import org.apache.lucene.util.fst.Util;
>>   */
>>  public class BooleanPerceptronClassifier implements Classifier<Boolean> {
>>
>> -  private Double threshold;
>> -  private final Integer batchSize;
>> -  private Terms textTerms;
>> -  private Analyzer analyzer;
>> -  private String textFieldName;
>> +  private final Double threshold;
>> +  private final Terms textTerms;
>> +  private final Analyzer analyzer;
>> +  private final String textFieldName;
>>    private FST<Long> fst;
>>
>> -  /**
>> -   * Create a {@link BooleanPerceptronClassifier}
>> -   *
>> -   * @param threshold the binary threshold for perceptron output
>> evaluation
>> -   */
>> -  public BooleanPerceptronClassifier(Double threshold, Integer
>> batchSize) {
>> -    this.threshold = threshold;
>> -    this.batchSize = batchSize;
>> -  }
>> -
>> -  /**
>> -   * Default constructor, no batch updates of FST, perceptron threshold
>> is
>> -   * calculated via underlying index metrics during
>> -   * {@link #train(org.apache.lucene.index.LeafReader, String, String,
>> org.apache.lucene.analysis.Analyzer)
>> -   * training}
>> -   */
>> -  public BooleanPerceptronClassifier() {
>> -    batchSize = 1;
>> -  }
>> -
>> -  /**
>> -   * {@inheritDoc}
>> -   */
>> -  @Override
>> -  public ClassificationResult<Boolean> assignClass(String text)
>> -      throws IOException {
>> -    if (textTerms == null) {
>> -      throw new IOException("You must first call Classifier#train");
>> -    }
>> -    Long output = 0l;
>> -    try (TokenStream tokenStream = analyzer.tokenStream(textFieldName,
>> text)) {
>> -      CharTermAttribute charTermAttribute = tokenStream
>> -          .addAttribute(CharTermAttribute.class);
>> -      tokenStream.reset();
>> -      while (tokenStream.incrementToken()) {
>> -        String s = charTermAttribute.toString();
>> -        Long d = Util.get(fst, new BytesRef(s));
>> -        if (d != null) {
>> -          output += d;
>> -        }
>> -      }
>> -      tokenStream.end();
>> -    }
>> -
>> -    double score = 1 - Math.exp(-1 * Math.abs(threshold -
>> output.doubleValue()) / threshold);
>> -    return new ClassificationResult<>(output >= threshold, score);
>> -  }
>> -
>> -  /**
>> -   * {@inheritDoc}
>> -   */
>> -  @Override
>> -  public void train(LeafReader leafReader, String textFieldName,
>> -                    String classFieldName, Analyzer analyzer) throws
>> IOException {
>> -    train(leafReader, textFieldName, classFieldName, analyzer, null);
>> -  }
>> -
>> -  /**
>> -   * {@inheritDoc}
>> -   */
>> -  @Override
>> -  public void train(LeafReader leafReader, String textFieldName,
>> -                    String classFieldName, Analyzer analyzer, Query
>> query) throws IOException {
>> +  public BooleanPerceptronClassifier(LeafReader leafReader, String
>> textFieldName, String classFieldName, Analyzer analyzer,
>> +                                     Query query, Integer batchSize,
>> Double threshold) throws IOException {
>>      this.textTerms = MultiFields.getTerms(leafReader, textFieldName);
>>
>>      if (textTerms == null) {
>> @@ -144,9 +82,11 @@ public class BooleanPerceptronClassifier
>>          this.threshold = (double) sumDocFreq / 2d;
>>        } else {
>>          throw new IOException(
>> -            "threshold cannot be assigned since term vectors for field "
>> -                + textFieldName + " do not exist");
>> +                "threshold cannot be assigned since term vectors for
>> field "
>> +                        + textFieldName + " do not exist");
>>        }
>> +    } else {
>> +      this.threshold = threshold;
>>      }
>>
>>      // TODO : remove this map as soon as we have a writable FST
>> @@ -170,7 +110,7 @@ public class BooleanPerceptronClassifier
>>      }
>>      // run the search and use stored field values
>>      for (ScoreDoc scoreDoc : indexSearcher.search(q,
>> -        Integer.MAX_VALUE).scoreDocs) {
>> +            Integer.MAX_VALUE).scoreDocs) {
>>        StoredDocument doc = indexSearcher.doc(scoreDoc.doc);
>>
>>        StorableField textField = doc.getField(textFieldName);
>> @@ -187,7 +127,7 @@ public class BooleanPerceptronClassifier
>>          long modifier = correctClass.compareTo(assignedClass);
>>          if (modifier != 0) {
>>            updateWeights(leafReader, scoreDoc.doc, assignedClass,
>> -                weights, modifier, batchCount % batchSize == 0);
>> +                  weights, modifier, batchCount % batchSize == 0);
>>          }
>>          batchCount++;
>>        }
>> @@ -195,11 +135,6 @@ public class BooleanPerceptronClassifier
>>      weights.clear(); // free memory while waiting for GC
>>    }
>>
>> -  @Override
>> -  public void train(LeafReader leafReader, String[] textFieldNames,
>> String classFieldName, Analyzer analyzer, Query query) throws IOException {
>> -    throw new IOException("training with multiple fields not supported
>> by boolean perceptron classifier");
>> -  }
>> -
>>    private void updateWeights(LeafReader leafReader,
>>                               int docId, Boolean assignedClass,
>> SortedMap<String, Double> weights,
>>                               double modifier, boolean updateFST) throws
>> IOException {
>> @@ -210,7 +145,7 @@ public class BooleanPerceptronClassifier
>>
>>      if (terms == null) {
>>        throw new IOException("term vectors must be stored for field "
>> -          + textFieldName);
>> +              + textFieldName);
>>      }
>>
>>      TermsEnum termsEnum = terms.iterator();
>> @@ -240,17 +175,46 @@ public class BooleanPerceptronClassifier
>>      for (Map.Entry<String, Double> entry : weights.entrySet()) {
>>        scratchBytes.copyChars(entry.getKey());
>>        fstBuilder.add(Util.toIntsRef(scratchBytes.get(), scratchInts),
>> entry
>> -          .getValue().longValue());
>> +              .getValue().longValue());
>>      }
>>      fst = fstBuilder.finish();
>>    }
>>
>> +
>> +  /**
>> +   * {@inheritDoc}
>> +   */
>> +  @Override
>> +  public ClassificationResult<Boolean> assignClass(String text)
>> +          throws IOException {
>> +    if (textTerms == null) {
>> +      throw new IOException("You must first call Classifier#train");
>> +    }
>> +    Long output = 0l;
>> +    try (TokenStream tokenStream = analyzer.tokenStream(textFieldName,
>> text)) {
>> +      CharTermAttribute charTermAttribute = tokenStream
>> +              .addAttribute(CharTermAttribute.class);
>> +      tokenStream.reset();
>> +      while (tokenStream.incrementToken()) {
>> +        String s = charTermAttribute.toString();
>> +        Long d = Util.get(fst, new BytesRef(s));
>> +        if (d != null) {
>> +          output += d;
>> +        }
>> +      }
>> +      tokenStream.end();
>> +    }
>> +
>> +    double score = 1 - Math.exp(-1 * Math.abs(threshold -
>> output.doubleValue()) / threshold);
>> +    return new ClassificationResult<>(output >= threshold, score);
>> +  }
>> +
>>    /**
>>     * {@inheritDoc}
>>     */
>>    @Override
>>    public List<ClassificationResult<Boolean>> getClasses(String text)
>> -      throws IOException {
>> +          throws IOException {
>>      throw new RuntimeException("not implemented");
>>    }
>>
>> @@ -259,7 +223,7 @@ public class BooleanPerceptronClassifier
>>     */
>>    @Override
>>    public List<ClassificationResult<Boolean>> getClasses(String text, int
>> max)
>> -      throws IOException {
>> +          throws IOException {
>>      throw new RuntimeException("not implemented");
>>    }
>>
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
>> Thu Apr 30 14:12:03 2015
>> @@ -49,50 +49,30 @@ import org.apache.lucene.util.BytesRef;
>>   */
>>  public class CachingNaiveBayesClassifier extends
>> SimpleNaiveBayesClassifier {
>>    //for caching classes this will be the classification class list
>> -  private ArrayList<BytesRef> cclasses = new ArrayList<>();
>> +  private final ArrayList<BytesRef> cclasses = new ArrayList<>();
>>    // it's a term-inmap style map, where the inmap contains class-hit
>> pairs to the
>>    // upper term
>> -  private Map<String, Map<BytesRef, Integer>> termCClassHitCache = new
>> HashMap<>();
>> +  private final Map<String, Map<BytesRef, Integer>> termCClassHitCache =
>> new HashMap<>();
>>    // the term frequency in classes
>> -  private Map<BytesRef, Double> classTermFreq = new HashMap<>();
>> +  private final Map<BytesRef, Double> classTermFreq = new HashMap<>();
>>    private boolean justCachedTerms;
>>    private int docsWithClassSize;
>>
>>    /**
>> -   * Creates a new NaiveBayes classifier with inside caching. Note that
>> you must
>> -   * call {@link #train(org.apache.lucene.index.LeafReader, String,
>> String, Analyzer) train()} before
>> -   * you can classify any documents. If you want less memory usage you
>> could
>> +   * Creates a new NaiveBayes classifier with inside caching. If you
>> want less memory usage you could
>>     * call {@link #reInitCache(int, boolean) reInitCache()}.
>>     */
>> -  public CachingNaiveBayesClassifier() {
>> -  }
>> -
>> -  /**
>> -   * {@inheritDoc}
>> -   */
>> -  @Override
>> -  public void train(LeafReader leafReader, String textFieldName, String
>> classFieldName, Analyzer analyzer) throws IOException {
>> -    train(leafReader, textFieldName, classFieldName, analyzer, null);
>> -  }
>> -
>> -  /**
>> -   * {@inheritDoc}
>> -   */
>> -  @Override
>> -  public void train(LeafReader leafReader, String textFieldName, String
>> classFieldName, Analyzer analyzer, Query query) throws IOException {
>> -    train(leafReader, new String[]{textFieldName}, classFieldName,
>> analyzer, query);
>> -  }
>> -
>> -  /**
>> -   * {@inheritDoc}
>> -   */
>> -  @Override
>> -  public void train(LeafReader leafReader, String[] textFieldNames,
>> String classFieldName, Analyzer analyzer, Query query) throws IOException {
>> -    super.train(leafReader, textFieldNames, classFieldName, analyzer,
>> query);
>> +  public CachingNaiveBayesClassifier(LeafReader leafReader, Analyzer
>> analyzer, Query query, String classFieldName, String... textFieldNames) {
>> +    super(leafReader, analyzer, query, classFieldName, textFieldNames);
>>      // building the cache
>> -    reInitCache(0, true);
>> +    try {
>> +      reInitCache(0, true);
>> +    } catch (IOException e) {
>> +      throw new RuntimeException(e);
>> +    }
>>    }
>>
>> +
>>    private List<ClassificationResult<BytesRef>>
>> assignClassNormalizedList(String inputDocument) throws IOException {
>>      if (leafReader == null) {
>>        throw new IOException("You must first call Classifier#train");
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java
>> Thu Apr 30 14:12:03 2015
>> @@ -18,17 +18,19 @@ package org.apache.lucene.classification
>>
>>  /**
>>   * The result of a call to {@link Classifier#assignClass(String)}
>> holding an assigned class of type <code>T</code> and a score.
>> + *
>>   * @lucene.experimental
>>   */
>> -public class ClassificationResult<T> implements
>> Comparable<ClassificationResult<T>>{
>> +public class ClassificationResult<T> implements
>> Comparable<ClassificationResult<T>> {
>>
>>    private final T assignedClass;
>>    private double score;
>>
>>    /**
>>     * Constructor
>> +   *
>>     * @param assignedClass the class <code>T</code> assigned by a {@link
>> Classifier}
>> -   * @param score the score for the assignedClass as a
>> <code>double</code>
>> +   * @param score         the score for the assignedClass as a
>> <code>double</code>
>>     */
>>    public ClassificationResult(T assignedClass, double score) {
>>      this.assignedClass = assignedClass;
>> @@ -37,6 +39,7 @@ public class ClassificationResult<T> imp
>>
>>    /**
>>     * retrieve the result class
>> +   *
>>     * @return a <code>T</code> representing an assigned class
>>     */
>>    public T getAssignedClass() {
>> @@ -45,14 +48,16 @@ public class ClassificationResult<T> imp
>>
>>    /**
>>     * retrieve the result score
>> +   *
>>     * @return a <code>double</code> representing a result score
>>     */
>>    public double getScore() {
>>      return score;
>>    }
>> -
>> +
>>    /**
>>     * set the score value
>> +   *
>>     * @param score the score for the assignedClass as a
>> <code>double</code>
>>     */
>>    public void setScore(double score) {
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
>> Thu Apr 30 14:12:03 2015
>> @@ -22,7 +22,6 @@ import java.util.List;
>>  import org.apache.lucene.analysis.Analyzer;
>>  import org.apache.lucene.index.LeafReader;
>>  import org.apache.lucene.search.Query;
>> -import org.apache.lucene.util.BytesRef;
>>
>>  /**
>>   * A classifier, see <code>
>> http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>, which
>> assign classes of type
>> @@ -39,7 +38,7 @@ public interface Classifier<T> {
>>     * @return a {@link ClassificationResult} holding assigned class of
>> type <code>T</code> and score
>>     * @throws IOException If there is a low-level I/O error.
>>     */
>> -  public ClassificationResult<T> assignClass(String text) throws
>> IOException;
>> +  ClassificationResult<T> assignClass(String text) throws IOException;
>>
>>    /**
>>     * Get all the classes (sorted by score, descending) assigned to the
>> given text String.
>> @@ -48,7 +47,7 @@ public interface Classifier<T> {
>>     * @return the whole list of {@link ClassificationResult}, the classes
>> and scores. Returns <code>null</code> if the classifier can't make lists.
>>     * @throws IOException If there is a low-level I/O error.
>>     */
>> -  public List<ClassificationResult<T>> getClasses(String text) throws
>> IOException;
>> +  List<ClassificationResult<T>> getClasses(String text) throws
>> IOException;
>>
>>    /**
>>     * Get the first <code>max</code> classes (sorted by score,
>> descending) assigned to the given text String.
>> @@ -58,44 +57,6 @@ public interface Classifier<T> {
>>     * @return the whole list of {@link ClassificationResult}, the classes
>> and scores. Cut for "max" number of elements. Returns <code>null</code> if
>> the classifier can't make lists.
>>     * @throws IOException If there is a low-level I/O error.
>>     */
>> -  public List<ClassificationResult<T>> getClasses(String text, int max)
>> throws IOException;
>> -
>> -  /**
>> -   * Train the classifier using the underlying Lucene index
>> -   *
>> -   * @param leafReader   the reader to use to access the Lucene index
>> -   * @param textFieldName  the name of the field used to compare
>> documents
>> -   * @param classFieldName the name of the field containing the class
>> assigned to documents
>> -   * @param analyzer       the analyzer used to tokenize / filter the
>> unseen text
>> -   * @throws IOException If there is a low-level I/O error.
>> -   */
>> -  public void train(LeafReader leafReader, String textFieldName, String
>> classFieldName, Analyzer analyzer)
>> -      throws IOException;
>> -
>> -  /**
>> -   * Train the classifier using the underlying Lucene index
>> -   *
>> -   * @param leafReader   the reader to use to access the Lucene index
>> -   * @param textFieldName  the name of the field used to compare
>> documents
>> -   * @param classFieldName the name of the field containing the class
>> assigned to documents
>> -   * @param analyzer       the analyzer used to tokenize / filter the
>> unseen text
>> -   * @param query          the query to filter which documents use for
>> training
>> -   * @throws IOException If there is a low-level I/O error.
>> -   */
>> -  public void train(LeafReader leafReader, String textFieldName, String
>> classFieldName, Analyzer analyzer, Query query)
>> -      throws IOException;
>> -
>> -  /**
>> -   * Train the classifier using the underlying Lucene index
>> -   *
>> -   * @param leafReader   the reader to use to access the Lucene index
>> -   * @param textFieldNames the names of the fields to be used to compare
>> documents
>> -   * @param classFieldName the name of the field containing the class
>> assigned to documents
>> -   * @param analyzer       the analyzer used to tokenize / filter the
>> unseen text
>> -   * @param query          the query to filter which documents use for
>> training
>> -   * @throws IOException If there is a low-level I/O error.
>> -   */
>> -  public void train(LeafReader leafReader, String[] textFieldNames,
>> String classFieldName, Analyzer analyzer, Query query)
>> -      throws IOException;
>> +  List<ClassificationResult<T>> getClasses(String text, int max) throws
>> IOException;
>>
>>  }
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
>> Thu Apr 30 14:12:03 2015
>> @@ -26,6 +26,7 @@ import java.util.Map;
>>
>>  import org.apache.lucene.analysis.Analyzer;
>>  import org.apache.lucene.index.LeafReader;
>> +import org.apache.lucene.index.StorableField;
>>  import org.apache.lucene.index.Term;
>>  import org.apache.lucene.queries.mlt.MoreLikeThis;
>>  import org.apache.lucene.search.BooleanClause;
>> @@ -45,37 +46,31 @@ import org.apache.lucene.util.BytesRef;
>>   */
>>  public class KNearestNeighborClassifier implements Classifier<BytesRef> {
>>
>> -  private MoreLikeThis mlt;
>> -  private String[] textFieldNames;
>> -  private String classFieldName;
>> -  private IndexSearcher indexSearcher;
>> +  private final MoreLikeThis mlt;
>> +  private final String[] textFieldNames;
>> +  private final String classFieldName;
>> +  private final IndexSearcher indexSearcher;
>>    private final int k;
>> -  private Query query;
>> +  private final Query query;
>>
>> -  private int minDocsFreq;
>> -  private int minTermFreq;
>> -
>> -  /**
>> -   * Create a {@link Classifier} using kNN algorithm
>> -   *
>> -   * @param k the number of neighbors to analyze as an <code>int</code>
>> -   */
>> -  public KNearestNeighborClassifier(int k) {
>> +  public KNearestNeighborClassifier(LeafReader leafReader, Analyzer
>> analyzer, Query query, int k, int minDocsFreq,
>> +                                    int minTermFreq, String
>> classFieldName, String... textFieldNames) {
>> +    this.textFieldNames = textFieldNames;
>> +    this.classFieldName = classFieldName;
>> +    this.mlt = new MoreLikeThis(leafReader);
>> +    this.mlt.setAnalyzer(analyzer);
>> +    this.mlt.setFieldNames(textFieldNames);
>> +    this.indexSearcher = new IndexSearcher(leafReader);
>> +    if (minDocsFreq > 0) {
>> +      mlt.setMinDocFreq(minDocsFreq);
>> +    }
>> +    if (minTermFreq > 0) {
>> +      mlt.setMinTermFreq(minTermFreq);
>> +    }
>> +    this.query = query;
>>      this.k = k;
>>    }
>>
>> -  /**
>> -   * Create a {@link Classifier} using kNN algorithm
>> -   *
>> -   * @param k           the number of neighbors to analyze as an
>> <code>int</code>
>> -   * @param minDocsFreq the minimum number of docs frequency for MLT to
>> be set with {@link MoreLikeThis#setMinDocFreq(int)}
>> -   * @param minTermFreq the minimum number of term frequency for MLT to
>> be set with {@link MoreLikeThis#setMinTermFreq(int)}
>> -   */
>> -  public KNearestNeighborClassifier(int k, int minDocsFreq, int
>> minTermFreq) {
>> -    this.k = k;
>> -    this.minDocsFreq = minDocsFreq;
>> -    this.minTermFreq = minTermFreq;
>> -  }
>>
>>    /**
>>     * {@inheritDoc}
>> @@ -136,12 +131,15 @@ public class KNearestNeighborClassifier
>>    private List<ClassificationResult<BytesRef>>
>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
>>      Map<BytesRef, Integer> classCounts = new HashMap<>();
>>      for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
>> -      BytesRef cl = new
>> BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue());
>> -      Integer count = classCounts.get(cl);
>> -      if (count != null) {
>> -        classCounts.put(cl, count + 1);
>> -      } else {
>> -        classCounts.put(cl, 1);
>> +      StorableField storableField =
>> indexSearcher.doc(scoreDoc.doc).getField(classFieldName);
>> +      if (storableField != null) {
>> +        BytesRef cl = new BytesRef(storableField.stringValue());
>> +        Integer count = classCounts.get(cl);
>> +        if (count != null) {
>> +          classCounts.put(cl, count + 1);
>> +        } else {
>> +          classCounts.put(cl, 1);
>> +        }
>>        }
>>      }
>>      List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
>> @@ -161,39 +159,4 @@ public class KNearestNeighborClassifier
>>      return returnList;
>>    }
>>
>> -  /**
>> -   * {@inheritDoc}
>> -   */
>> -  @Override
>> -  public void train(LeafReader leafReader, String textFieldName, String
>> classFieldName, Analyzer analyzer) throws IOException {
>> -    train(leafReader, textFieldName, classFieldName, analyzer, null);
>> -  }
>> -
>> -  /**
>> -   * {@inheritDoc}
>> -   */
>> -  @Override
>> -  public void train(LeafReader leafReader, String textFieldName, String
>> classFieldName, Analyzer analyzer, Query query) throws IOException {
>> -    train(leafReader, new String[]{textFieldName}, classFieldName,
>> analyzer, query);
>> -  }
>> -
>> -  /**
>> -   * {@inheritDoc}
>> -   */
>> -  @Override
>> -  public void train(LeafReader leafReader, String[] textFieldNames,
>> String classFieldName, Analyzer analyzer, Query query) throws IOException {
>> -    this.textFieldNames = textFieldNames;
>> -    this.classFieldName = classFieldName;
>> -    mlt = new MoreLikeThis(leafReader);
>> -    mlt.setAnalyzer(analyzer);
>> -    mlt.setFieldNames(textFieldNames);
>> -    indexSearcher = new IndexSearcher(leafReader);
>> -    if (minDocsFreq > 0) {
>> -      mlt.setMinDocFreq(minDocsFreq);
>> -    }
>> -    if (minTermFreq > 0) {
>> -      mlt.setMinTermFreq(minTermFreq);
>> -    }
>> -    this.query = query;
>> -  }
>>  }
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
>> Thu Apr 30 14:12:03 2015
>> @@ -51,64 +51,38 @@ public class SimpleNaiveBayesClassifier
>>     * {@link org.apache.lucene.index.LeafReader} used to access the
>> {@link org.apache.lucene.classification.Classifier}'s
>>     * index
>>     */
>> -  protected LeafReader leafReader;
>> +  protected final LeafReader leafReader;
>>
>>    /**
>>     * names of the fields to be used as input text
>>     */
>> -  protected String[] textFieldNames;
>> +  protected final String[] textFieldNames;
>>
>>    /**
>>     * name of the field to be used as a class / category output
>>     */
>> -  protected String classFieldName;
>> +  protected final String classFieldName;
>>
>>    /**
>>     * {@link org.apache.lucene.analysis.Analyzer} to be used for
>> tokenizing unseen input text
>>     */
>> -  protected Analyzer analyzer;
>> +  protected final Analyzer analyzer;
>>
>>    /**
>>     * {@link org.apache.lucene.search.IndexSearcher} to run searches on
>> the index for retrieving frequencies
>>     */
>> -  protected IndexSearcher indexSearcher;
>> +  protected final IndexSearcher indexSearcher;
>>
>>    /**
>>     * {@link org.apache.lucene.search.Query} used to eventually filter
>> the document set to be used to classify
>>     */
>> -  protected Query query;
>> +  protected final Query query;
>>
>>    /**
>>     * Creates a new NaiveBayes classifier.
>> -   * Note that you must call {@link
>> #train(org.apache.lucene.index.LeafReader, String, String, Analyzer)
>> train()} before you can
>>     * classify any documents.
>>     */
>> -  public SimpleNaiveBayesClassifier() {
>> -  }
>> -
>> -  /**
>> -   * {@inheritDoc}
>> -   */
>> -  @Override
>> -  public void train(LeafReader leafReader, String textFieldName, String
>> classFieldName, Analyzer analyzer) throws IOException {
>> -    train(leafReader, textFieldName, classFieldName, analyzer, null);
>> -  }
>> -
>> -  /**
>> -   * {@inheritDoc}
>> -   */
>> -  @Override
>> -  public void train(LeafReader leafReader, String textFieldName, String
>> classFieldName, Analyzer analyzer, Query query)
>> -      throws IOException {
>> -    train(leafReader, new String[]{textFieldName}, classFieldName,
>> analyzer, query);
>> -  }
>> -
>> -  /**
>> -   * {@inheritDoc}
>> -   */
>> -  @Override
>> -  public void train(LeafReader leafReader, String[] textFieldNames,
>> String classFieldName, Analyzer analyzer, Query query)
>> -      throws IOException {
>> +  public SimpleNaiveBayesClassifier(LeafReader leafReader, Analyzer
>> analyzer, Query query, String classFieldName, String... textFieldNames) {
>>      this.leafReader = leafReader;
>>      this.indexSearcher = new IndexSearcher(this.leafReader);
>>      this.textFieldNames = textFieldNames;
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java
>> Thu Apr 30 14:12:03 2015
>> @@ -18,7 +18,7 @@
>>  /**
>>   * Uses already seen data (the indexed documents) to classify new
>> documents.
>>   * <p>
>> - * Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest
>> + * Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest
>>   * Neighbor classifier and a Perceptron based classifier.
>>   */
>>  package org.apache.lucene.classification;
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java
>> Thu Apr 30 14:12:03 2015
>> @@ -33,7 +33,8 @@ public class DocToDoubleVectorUtils {
>>
>>    /**
>>     * create a sparse <code>Double</code> vector given doc and field term
>> vectors using local frequency of the terms in the doc
>> -   * @param docTerms term vectors for a given document
>> +   *
>> +   * @param docTerms   term vectors for a given document
>>     * @param fieldTerms field term vectors
>>     * @return a sparse vector of <code>Double</code>s as an array
>>     * @throws IOException in case accessing the underlying index fails
>> @@ -54,8 +55,7 @@ public class DocToDoubleVectorUtils {
>>          if (seekStatus.equals(TermsEnum.SeekStatus.FOUND)) {
>>            long termFreqLocal = docTermsEnum.totalTermFreq(); // the
>> total number of occurrences of this term in the given document
>>            freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
>> -        }
>> -        else {
>> +        } else {
>>            freqVector[i] = 0d;
>>          }
>>          i++;
>> @@ -66,6 +66,7 @@ public class DocToDoubleVectorUtils {
>>
>>    /**
>>     * create a dense <code>Double</code> vector given doc and field term
>> vectors using local frequency of the terms in the doc
>> +   *
>>     * @param docTerms term vectors for a given document
>>     * @return a dense vector of <code>Double</code>s as an array
>>     * @throws IOException in case accessing the underlying index fails
>> @@ -73,16 +74,16 @@ public class DocToDoubleVectorUtils {
>>    public static Double[] toDenseLocalFreqDoubleArray(Terms docTerms)
>> throws IOException {
>>      Double[] freqVector = null;
>>      if (docTerms != null) {
>> -        freqVector = new Double[(int) docTerms.size()];
>> -        int i = 0;
>> -        TermsEnum docTermsEnum = docTerms.iterator();
>> +      freqVector = new Double[(int) docTerms.size()];
>> +      int i = 0;
>> +      TermsEnum docTermsEnum = docTerms.iterator();
>>
>> -        while (docTermsEnum.next() != null) {
>> -            long termFreqLocal = docTermsEnum.totalTermFreq(); // the
>> total number of occurrences of this term in the given document
>> -            freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
>> -            i++;
>> -        }
>> +      while (docTermsEnum.next() != null) {
>> +        long termFreqLocal = docTermsEnum.totalTermFreq(); // the total
>> number of occurrences of this term in the given document
>> +        freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
>> +        i++;
>> +      }
>>      }
>>      return freqVector;
>> -}
>> +  }
>>  }
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java
>> Thu Apr 30 14:12:03 2015
>> @@ -17,6 +17,8 @@
>>  package org.apache.lucene.classification;
>>
>>  import org.apache.lucene.analysis.MockAnalyzer;
>> +import org.apache.lucene.index.LeafReader;
>> +import org.apache.lucene.index.SlowCompositeReaderWrapper;
>>  import org.apache.lucene.index.Term;
>>  import org.apache.lucene.search.TermQuery;
>>  import org.junit.Test;
>> @@ -28,22 +30,45 @@ public class BooleanPerceptronClassifier
>>
>>    @Test
>>    public void testBasicUsage() throws Exception {
>> -    checkCorrectClassification(new BooleanPerceptronClassifier(),
>> TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName,
>> booleanFieldName);
>> +    LeafReader leafReader = null;
>> +    try {
>> +      MockAnalyzer analyzer = new MockAnalyzer(random());
>> +      leafReader = populateSampleIndex(analyzer);
>> +      checkCorrectClassification(new
>> BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName,
>> analyzer, null, 1, null), TECHNOLOGY_INPUT, false);
>> +    } finally {
>> +      if (leafReader != null) {
>> +        leafReader.close();
>> +      }
>> +    }
>>    }
>>
>>    @Test
>>    public void testExplicitThreshold() throws Exception {
>> -    checkCorrectClassification(new BooleanPerceptronClassifier(100d, 1),
>> TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName,
>> booleanFieldName);
>> +    LeafReader leafReader = null;
>> +    try {
>> +      MockAnalyzer analyzer = new MockAnalyzer(random());
>> +      leafReader = populateSampleIndex(analyzer);
>> +      checkCorrectClassification(new
>> BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName,
>> analyzer, null, 1, 100d), TECHNOLOGY_INPUT, false);
>> +    } finally {
>> +      if (leafReader != null) {
>> +        leafReader.close();
>> +      }
>> +    }
>>    }
>>
>>    @Test
>>    public void testBasicUsageWithQuery() throws Exception {
>> -    checkCorrectClassification(new BooleanPerceptronClassifier(),
>> TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName,
>> booleanFieldName, new TermQuery(new Term(textFieldName, "it")));
>> -  }
>> -
>> -  @Test
>> -  public void testPerformance() throws Exception {
>> -    checkPerformance(new BooleanPerceptronClassifier(), new
>> MockAnalyzer(random()), booleanFieldName);
>> +    TermQuery query = new TermQuery(new Term(textFieldName, "it"));
>> +    LeafReader leafReader = null;
>> +    try {
>> +      MockAnalyzer analyzer = new MockAnalyzer(random());
>> +      leafReader = populateSampleIndex(analyzer);
>> +      checkCorrectClassification(new
>> BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName,
>> analyzer, query, 1, null), TECHNOLOGY_INPUT, false);
>> +    } finally {
>> +      if (leafReader != null) {
>> +        leafReader.close();
>> +      }
>> +    }
>>    }
>>
>>  }
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java
>> Thu Apr 30 14:12:03 2015
>> @@ -23,6 +23,8 @@ import org.apache.lucene.analysis.Tokeni
>>  import org.apache.lucene.analysis.core.KeywordTokenizer;
>>  import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
>>  import org.apache.lucene.analysis.reverse.ReverseStringFilter;
>> +import org.apache.lucene.index.LeafReader;
>> +import org.apache.lucene.index.SlowCompositeReaderWrapper;
>>  import org.apache.lucene.index.Term;
>>  import org.apache.lucene.search.TermQuery;
>>  import org.apache.lucene.util.BytesRef;
>> @@ -35,18 +37,46 @@ public class CachingNaiveBayesClassifier
>>
>>    @Test
>>    public void testBasicUsage() throws Exception {
>> -    checkCorrectClassification(new CachingNaiveBayesClassifier(),
>> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()),
>> textFieldName, categoryFieldName);
>> -    checkCorrectClassification(new CachingNaiveBayesClassifier(),
>> POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName,
>> categoryFieldName);
>> +    LeafReader leafReader = null;
>> +    try {
>> +      MockAnalyzer analyzer = new MockAnalyzer(random());
>> +      leafReader = populateSampleIndex(analyzer);
>> +      checkCorrectClassification(new
>> CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName,
>> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
>> +      checkCorrectClassification(new
>> CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName,
>> textFieldName), POLITICS_INPUT, POLITICS_RESULT);
>> +    } finally {
>> +      if (leafReader != null) {
>> +        leafReader.close();
>> +      }
>> +    }
>>    }
>>
>>    @Test
>>    public void testBasicUsageWithQuery() throws Exception {
>> -    checkCorrectClassification(new CachingNaiveBayesClassifier(),
>> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()),
>> textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName,
>> "it")));
>> +    LeafReader leafReader = null;
>> +    try {
>> +      MockAnalyzer analyzer = new MockAnalyzer(random());
>> +      leafReader = populateSampleIndex(analyzer);
>> +      TermQuery query = new TermQuery(new Term(textFieldName, "it"));
>> +      checkCorrectClassification(new
>> CachingNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName,
>> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
>> +    } finally {
>> +      if (leafReader != null) {
>> +        leafReader.close();
>> +      }
>> +    }
>>    }
>>
>>    @Test
>>    public void testNGramUsage() throws Exception {
>> -    checkCorrectClassification(new CachingNaiveBayesClassifier(),
>> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName,
>> categoryFieldName);
>> +    LeafReader leafReader = null;
>> +    try {
>> +      NGramAnalyzer analyzer = new NGramAnalyzer();
>> +      leafReader = populateSampleIndex(analyzer);
>> +      checkCorrectClassification(new
>> CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName,
>> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
>> +    } finally {
>> +      if (leafReader != null) {
>> +        leafReader.close();
>> +      }
>> +    }
>>    }
>>
>>    private class NGramAnalyzer extends Analyzer {
>> @@ -57,9 +87,4 @@ public class CachingNaiveBayesClassifier
>>      }
>>    }
>>
>> -  @Test
>> -  public void testPerformance() throws Exception {
>> -    checkPerformance(new CachingNaiveBayesClassifier(), new
>> MockAnalyzer(random()), categoryFieldName);
>> -  }
>> -
>>  }
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
>> Thu Apr 30 14:12:03 2015
>> @@ -41,14 +41,14 @@ import org.junit.Before;
>>   */
>>  public abstract class ClassificationTestBase<T> extends LuceneTestCase {
>>    public final static String POLITICS_INPUT = "Here are some interesting
>> questions and answers about Mitt Romney.. " +
>> -      "If you don't know the answer to the question about Mitt Romney,
>> then simply click on the answer below the question section.";
>> +          "If you don't know the answer to the question about Mitt
>> Romney, then simply click on the answer below the question section.";
>>    public static final BytesRef POLITICS_RESULT = new
>> BytesRef("politics");
>>
>>    public static final String TECHNOLOGY_INPUT = "Much is made of what
>> the likes of Facebook, Google and Apple know about users." +
>> -      " Truth is, Amazon may know more.";
>> +          " Truth is, Amazon may know more.";
>>    public static final BytesRef TECHNOLOGY_RESULT = new
>> BytesRef("technology");
>>
>> -  private RandomIndexWriter indexWriter;
>> +  protected RandomIndexWriter indexWriter;
>>    private Directory dir;
>>    private FieldType ft;
>>
>> @@ -79,53 +79,34 @@ public abstract class ClassificationTest
>>      dir.close();
>>    }
>>
>> -  protected void checkCorrectClassification(Classifier<T> classifier,
>> String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName,
>> String classFieldName) throws Exception {
>> -    checkCorrectClassification(classifier, inputDoc, expectedResult,
>> analyzer, textFieldName, classFieldName, null);
>> +  protected void checkCorrectClassification(Classifier<T> classifier,
>> String inputDoc, T expectedResult) throws Exception {
>> +    ClassificationResult<T> classificationResult =
>> classifier.assignClass(inputDoc);
>> +    assertNotNull(classificationResult.getAssignedClass());
>> +    assertEquals("got an assigned class of " +
>> classificationResult.getAssignedClass(), expectedResult,
>> classificationResult.getAssignedClass());
>> +    double score = classificationResult.getScore();
>> +    assertTrue("score should be between 0 and 1, got:" + score, score <=
>> 1 && score >= 0);
>>    }
>>
>> -  protected void checkCorrectClassification(Classifier<T> classifier,
>> String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName,
>> String classFieldName, Query query) throws Exception {
>> -    LeafReader leafReader = null;
>> -    try {
>> -      populateSampleIndex(analyzer);
>> -      leafReader =
>> SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
>> -      classifier.train(leafReader, textFieldName, classFieldName,
>> analyzer, query);
>> -      ClassificationResult<T> classificationResult =
>> classifier.assignClass(inputDoc);
>> -      assertNotNull(classificationResult.getAssignedClass());
>> -      assertEquals("got an assigned class of " +
>> classificationResult.getAssignedClass(), expectedResult,
>> classificationResult.getAssignedClass());
>> -      double score = classificationResult.getScore();
>> -      assertTrue("score should be between 0 and 1, got:" + score, score
>> <= 1 && score >= 0);
>> -    } finally {
>> -      if (leafReader != null)
>> -        leafReader.close();
>> -    }
>> -  }
>>    protected void checkOnlineClassification(Classifier<T> classifier,
>> String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName,
>> String classFieldName) throws Exception {
>>      checkOnlineClassification(classifier, inputDoc, expectedResult,
>> analyzer, textFieldName, classFieldName, null);
>>    }
>>
>>    protected void checkOnlineClassification(Classifier<T> classifier,
>> String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName,
>> String classFieldName, Query query) throws Exception {
>> -    LeafReader leafReader = null;
>> -    try {
>> -      populateSampleIndex(analyzer);
>> -      leafReader =
>> SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
>> -      classifier.train(leafReader, textFieldName, classFieldName,
>> analyzer, query);
>> -      ClassificationResult<T> classificationResult =
>> classifier.assignClass(inputDoc);
>> -      assertNotNull(classificationResult.getAssignedClass());
>> -      assertEquals("got an assigned class of " +
>> classificationResult.getAssignedClass(), expectedResult,
>> classificationResult.getAssignedClass());
>> -      double score = classificationResult.getScore();
>> -      assertTrue("score should be between 0 and 1, got: " + score, score
>> <= 1 && score >= 0);
>> -      updateSampleIndex();
>> -      ClassificationResult<T> secondClassificationResult =
>> classifier.assignClass(inputDoc);
>> -      assertEquals(classificationResult.getAssignedClass(),
>> secondClassificationResult.getAssignedClass());
>> -      assertEquals(Double.valueOf(score),
>> Double.valueOf(secondClassificationResult.getScore()));
>> -
>> -    } finally {
>> -      if (leafReader != null)
>> -        leafReader.close();
>> -    }
>> +    populateSampleIndex(analyzer);
>> +
>> +    ClassificationResult<T> classificationResult =
>> classifier.assignClass(inputDoc);
>> +    assertNotNull(classificationResult.getAssignedClass());
>> +    assertEquals("got an assigned class of " +
>> classificationResult.getAssignedClass(), expectedResult,
>> classificationResult.getAssignedClass());
>> +    double score = classificationResult.getScore();
>> +    assertTrue("score should be between 0 and 1, got: " + score, score
>> <= 1 && score >= 0);
>> +    updateSampleIndex();
>> +    ClassificationResult<T> secondClassificationResult =
>> classifier.assignClass(inputDoc);
>> +    assertEquals(classificationResult.getAssignedClass(),
>> secondClassificationResult.getAssignedClass());
>> +    assertEquals(Double.valueOf(score),
>> Double.valueOf(secondClassificationResult.getScore()));
>> +
>>    }
>>
>> -  private void populateSampleIndex(Analyzer analyzer) throws IOException
>> {
>> +  protected LeafReader populateSampleIndex(Analyzer analyzer) throws
>> IOException {
>>      indexWriter.close();
>>      indexWriter = new RandomIndexWriter(random(), dir,
>> newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE));
>>      indexWriter.commit();
>> @@ -134,8 +115,8 @@ public abstract class ClassificationTest
>>
>>      Document doc = new Document();
>>      text = "The traveling press secretary for Mitt Romney lost his cool
>> and cursed at reporters " +
>> -        "who attempted to ask questions of the Republican presidential
>> candidate in a public plaza near the Tomb of " +
>> -        "the Unknown Soldier in Warsaw Tuesday.";
>> +            "who attempted to ask questions of the Republican
>> presidential candidate in a public plaza near the Tomb of " +
>> +            "the Unknown Soldier in Warsaw Tuesday.";
>>      doc.add(new Field(textFieldName, text, ft));
>>      doc.add(new Field(categoryFieldName, "politics", ft));
>>      doc.add(new Field(booleanFieldName, "true", ft));
>> @@ -144,7 +125,7 @@ public abstract class ClassificationTest
>>
>>      doc = new Document();
>>      text = "Mitt Romney seeks to assure Israel and Iran, as well as
>> Jewish voters in the United" +
>> -        " States, that he will be tougher against Iran's nuclear
>> ambitions than President Barack Obama.";
>> +            " States, that he will be tougher against Iran's nuclear
>> ambitions than President Barack Obama.";
>>      doc.add(new Field(textFieldName, text, ft));
>>      doc.add(new Field(categoryFieldName, "politics", ft));
>>      doc.add(new Field(booleanFieldName, "true", ft));
>> @@ -152,8 +133,8 @@ public abstract class ClassificationTest
>>
>>      doc = new Document();
>>      text = "And there's a threshold question that he has to answer for
>> the American people and " +
>> -        "that's whether he is prepared to be commander-in-chief,\" she
>> continued. \"As we look to the past events, we " +
>> -        "know that this raises some questions about his preparedness and
>> we'll see how the rest of his trip goes.\"";
>> +            "that's whether he is prepared to be commander-in-chief,\"
>> she continued. \"As we look to the past events, we " +
>> +            "know that this raises some questions about his preparedness
>> and we'll see how the rest of his trip goes.\"";
>>      doc.add(new Field(textFieldName, text, ft));
>>      doc.add(new Field(categoryFieldName, "politics", ft));
>>      doc.add(new Field(booleanFieldName, "true", ft));
>> @@ -161,8 +142,8 @@ public abstract class ClassificationTest
>>
>>      doc = new Document();
>>      text = "Still, when it comes to gun policy, many congressional
>> Democrats have \"decided to " +
>> -        "keep quiet and not go there,\" said Alan Lizotte, dean and
>> professor at the State University of New York at " +
>> -        "Albany's School of Criminal Justice.";
>> +            "keep quiet and not go there,\" said Alan Lizotte, dean and
>> professor at the State University of New York at " +
>> +            "Albany's School of Criminal Justice.";
>>      doc.add(new Field(textFieldName, text, ft));
>>      doc.add(new Field(categoryFieldName, "politics", ft));
>>      doc.add(new Field(booleanFieldName, "true", ft));
>> @@ -170,8 +151,8 @@ public abstract class ClassificationTest
>>
>>      doc = new Document();
>>      text = "Standing amongst the thousands of people at the state
>> Capitol, Jorstad, director of " +
>> -        "technology at the University of Wisconsin-La Crosse, documented
>> the historic moment and shared it with the " +
>> -        "world through the Internet.";
>> +            "technology at the University of Wisconsin-La Crosse,
>> documented the historic moment and shared it with the " +
>> +            "world through the Internet.";
>>      doc.add(new Field(textFieldName, text, ft));
>>      doc.add(new Field(categoryFieldName, "technology", ft));
>>      doc.add(new Field(booleanFieldName, "false", ft));
>> @@ -179,7 +160,7 @@ public abstract class ClassificationTest
>>
>>      doc = new Document();
>>      text = "So, about all those experts and analysts who've spent the
>> past year or so saying " +
>> -        "Facebook was going to make a phone. A new expert has stepped
>> forward to say it's not going to happen.";
>> +            "Facebook was going to make a phone. A new expert has
>> stepped forward to say it's not going to happen.";
>>      doc.add(new Field(textFieldName, text, ft));
>>      doc.add(new Field(categoryFieldName, "technology", ft));
>>      doc.add(new Field(booleanFieldName, "false", ft));
>> @@ -187,8 +168,8 @@ public abstract class ClassificationTest
>>
>>      doc = new Document();
>>      text = "More than 400 million people trust Google with their e-mail,
>> and 50 million store files" +
>> -        " in the cloud using the Dropbox service. People manage their
>> bank accounts, pay bills, trade stocks and " +
>> -        "generally transfer or store huge volumes of personal data
>> online.";
>> +            " in the cloud using the Dropbox service. People manage
>> their bank accounts, pay bills, trade stocks and " +
>> +            "generally transfer or store huge volumes of personal data
>> online.";
>>      doc.add(new Field(textFieldName, text, ft));
>>      doc.add(new Field(categoryFieldName, "technology", ft));
>>      doc.add(new Field(booleanFieldName, "false", ft));
>> @@ -200,22 +181,15 @@ public abstract class ClassificationTest
>>      indexWriter.addDocument(doc);
>>
>>      indexWriter.commit();
>> +    return SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
>>    }
>>
>>    protected void checkPerformance(Classifier<T> classifier, Analyzer
>> analyzer, String classFieldName) throws Exception {
>> -    LeafReader leafReader = null;
>>      long trainStart = System.currentTimeMillis();
>> -    try {
>> -      populatePerformanceIndex(analyzer);
>> -      leafReader =
>> SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
>> -      classifier.train(leafReader, textFieldName, classFieldName,
>> analyzer);
>> -      long trainEnd = System.currentTimeMillis();
>> -      long trainTime = trainEnd - trainStart;
>> -      assertTrue("training took more than 2 mins : " + trainTime / 1000
>> + "s", trainTime < 120000);
>> -    } finally {
>> -      if (leafReader != null)
>> -        leafReader.close();
>> -    }
>> +    populatePerformanceIndex(analyzer);
>> +    long trainEnd = System.currentTimeMillis();
>> +    long trainTime = trainEnd - trainStart;
>> +    assertTrue("training took more than 2 mins : " + trainTime / 1000 +
>> "s", trainTime < 120000);
>>    }
>>
>>    private void populatePerformanceIndex(Analyzer analyzer) throws
>> IOException {
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
>> Thu Apr 30 14:12:03 2015
>> @@ -17,6 +17,8 @@
>>  package org.apache.lucene.classification;
>>
>>  import org.apache.lucene.analysis.MockAnalyzer;
>> +import org.apache.lucene.index.LeafReader;
>> +import org.apache.lucene.index.SlowCompositeReaderWrapper;
>>  import org.apache.lucene.index.Term;
>>  import org.apache.lucene.search.TermQuery;
>>  import org.apache.lucene.util.BytesRef;
>> @@ -29,20 +31,32 @@ public class KNearestNeighborClassifierT
>>
>>    @Test
>>    public void testBasicUsage() throws Exception {
>> -    // usage with default MLT min docs / term freq
>> -    checkCorrectClassification(new KNearestNeighborClassifier(3),
>> POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName,
>> categoryFieldName);
>> -    // usage without custom min docs / term freq for MLT
>> -    checkCorrectClassification(new KNearestNeighborClassifier(3, 2, 1),
>> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()),
>> textFieldName, categoryFieldName);
>> +    LeafReader leafReader = null;
>> +    try {
>> +      MockAnalyzer analyzer = new MockAnalyzer(random());
>> +      leafReader = populateSampleIndex(analyzer);
>> +      checkCorrectClassification(new
>> KNearestNeighborClassifier(leafReader, analyzer, null, 1, 0, 0,
>> categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
>> +      checkCorrectClassification(new
>> KNearestNeighborClassifier(leafReader, analyzer, null, 3, 2, 1,
>> categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
>> +    } finally {
>> +      if (leafReader != null) {
>> +        leafReader.close();
>> +      }
>> +    }
>>    }
>>
>>    @Test
>>    public void testBasicUsageWithQuery() throws Exception {
>> -    checkCorrectClassification(new KNearestNeighborClassifier(1),
>> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()),
>> textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName,
>> "it")));
>> -  }
>> -
>> -  @Test
>> -  public void testPerformance() throws Exception {
>> -    checkPerformance(new KNearestNeighborClassifier(100), new
>> MockAnalyzer(random()), categoryFieldName);
>> +    LeafReader leafReader = null;
>> +    try {
>> +      MockAnalyzer analyzer = new MockAnalyzer(random());
>> +      leafReader = populateSampleIndex(analyzer);
>> +      TermQuery query = new TermQuery(new Term(textFieldName, "it"));
>> +      checkCorrectClassification(new
>> KNearestNeighborClassifier(leafReader, analyzer, query, 1, 0, 0,
>> categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
>> +    } finally {
>> +      if (leafReader != null) {
>> +        leafReader.close();
>> +      }
>> +    }
>>    }
>>
>>  }
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java
>> Thu Apr 30 14:12:03 2015
>> @@ -22,14 +22,13 @@ import org.apache.lucene.analysis.Tokeni
>>  import org.apache.lucene.analysis.core.KeywordTokenizer;
>>  import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
>>  import org.apache.lucene.analysis.reverse.ReverseStringFilter;
>> +import org.apache.lucene.index.LeafReader;
>> +import org.apache.lucene.index.SlowCompositeReaderWrapper;
>>  import org.apache.lucene.index.Term;
>>  import org.apache.lucene.search.TermQuery;
>>  import org.apache.lucene.util.BytesRef;
>> -import org.apache.lucene.util.LuceneTestCase;
>>  import org.junit.Test;
>>
>> -import java.io.Reader;
>> -
>>  /**
>>   * Testcase for {@link SimpleNaiveBayesClassifier}
>>   */
>> @@ -37,18 +36,46 @@ public class SimpleNaiveBayesClassifierT
>>
>>    @Test
>>    public void testBasicUsage() throws Exception {
>> -    checkCorrectClassification(new SimpleNaiveBayesClassifier(),
>> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()),
>> textFieldName, categoryFieldName);
>> -    checkCorrectClassification(new SimpleNaiveBayesClassifier(),
>> POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName,
>> categoryFieldName);
>> +    LeafReader leafReader = null;
>> +    try {
>> +      MockAnalyzer analyzer = new MockAnalyzer(random());
>> +      leafReader = populateSampleIndex(analyzer);
>> +      checkCorrectClassification(new
>> SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName,
>> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
>> +      checkCorrectClassification(new
>> SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName,
>> textFieldName), POLITICS_INPUT, POLITICS_RESULT);
>> +    } finally {
>> +      if (leafReader != null) {
>> +        leafReader.close();
>> +      }
>> +    }
>>    }
>>
>>    @Test
>>    public void testBasicUsageWithQuery() throws Exception {
>> -    checkCorrectClassification(new SimpleNaiveBayesClassifier(),
>> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()),
>> textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName,
>> "it")));
>> +    LeafReader leafReader = null;
>> +    try {
>> +      MockAnalyzer analyzer = new MockAnalyzer(random());
>> +      leafReader = populateSampleIndex(analyzer);
>> +      TermQuery query = new TermQuery(new Term(textFieldName, "it"));
>> +      checkCorrectClassification(new
>> SimpleNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName,
>> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
>> +    } finally {
>> +      if (leafReader != null) {
>> +        leafReader.close();
>> +      }
>> +    }
>>    }
>>
>>    @Test
>>    public void testNGramUsage() throws Exception {
>> -    checkCorrectClassification(new SimpleNaiveBayesClassifier(),
>> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName,
>> categoryFieldName);
>> +    LeafReader leafReader = null;
>> +    try {
>> +      Analyzer analyzer = new NGramAnalyzer();
>> +      leafReader = populateSampleIndex(analyzer);
>> +      checkCorrectClassification(new
>> CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName,
>> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
>> +    } finally {
>> +      if (leafReader != null) {
>> +        leafReader.close();
>> +      }
>> +    }
>>    }
>>
>>    private class NGramAnalyzer extends Analyzer {
>> @@ -59,9 +86,4 @@ public class SimpleNaiveBayesClassifierT
>>      }
>>    }
>>
>> -  @Test
>> -  public void testPerformance() throws Exception {
>> -    checkPerformance(new SimpleNaiveBayesClassifier(), new
>> MockAnalyzer(random()), categoryFieldName);
>> -  }
>> -
>>  }
>>
>>
>>
>