You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@lucene.apache.org by Tommaso Teofili <to...@gmail.com> on 2015/05/03 08:55:50 UTC
Re: svn commit: r1676997 - in /lucene/dev/trunk/lucene/classification/src:
java/org/apache/lucene/classification/ java/org/apache/lucene/classification/utils/
test/org/apache/lucene/classification/
Sorry for the inconvenience: the refactored APIs changed the constructors, but I forgot
to update the related javadocs. It should be fixed now.
Tommaso
2015-04-30 18:28 GMT+02:00 Ryan Ernst <ry...@iernst.net>:
> Did you mean to remove the javadocs on BooleanPerceptronClassifier? This
> breaks precommit...
>
> On Thu, Apr 30, 2015 at 7:12 AM, <to...@apache.org> wrote:
>
>> Author: tommaso
>> Date: Thu Apr 30 14:12:03 2015
>> New Revision: 1676997
>>
>> URL: http://svn.apache.org/r1676997
>> Log:
>> LUCENE-6045 - refactor Classifier API to work better with multithreading
>>
>> Modified:
>>
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
>>
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
>>
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java
>>
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
>>
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
>>
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
>>
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java
>>
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java
>>
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java
>>
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java
>>
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
>>
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
>>
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
>> Thu Apr 30 14:12:03 2015
>> @@ -58,76 +58,14 @@ import org.apache.lucene.util.fst.Util;
>> */
>> public class BooleanPerceptronClassifier implements Classifier<Boolean> {
>>
>> - private Double threshold;
>> - private final Integer batchSize;
>> - private Terms textTerms;
>> - private Analyzer analyzer;
>> - private String textFieldName;
>> + private final Double threshold;
>> + private final Terms textTerms;
>> + private final Analyzer analyzer;
>> + private final String textFieldName;
>> private FST<Long> fst;
>>
>> - /**
>> - * Create a {@link BooleanPerceptronClassifier}
>> - *
>> - * @param threshold the binary threshold for perceptron output
>> evaluation
>> - */
>> - public BooleanPerceptronClassifier(Double threshold, Integer
>> batchSize) {
>> - this.threshold = threshold;
>> - this.batchSize = batchSize;
>> - }
>> -
>> - /**
>> - * Default constructor, no batch updates of FST, perceptron threshold
>> is
>> - * calculated via underlying index metrics during
>> - * {@link #train(org.apache.lucene.index.LeafReader, String, String,
>> org.apache.lucene.analysis.Analyzer)
>> - * training}
>> - */
>> - public BooleanPerceptronClassifier() {
>> - batchSize = 1;
>> - }
>> -
>> - /**
>> - * {@inheritDoc}
>> - */
>> - @Override
>> - public ClassificationResult<Boolean> assignClass(String text)
>> - throws IOException {
>> - if (textTerms == null) {
>> - throw new IOException("You must first call Classifier#train");
>> - }
>> - Long output = 0l;
>> - try (TokenStream tokenStream = analyzer.tokenStream(textFieldName,
>> text)) {
>> - CharTermAttribute charTermAttribute = tokenStream
>> - .addAttribute(CharTermAttribute.class);
>> - tokenStream.reset();
>> - while (tokenStream.incrementToken()) {
>> - String s = charTermAttribute.toString();
>> - Long d = Util.get(fst, new BytesRef(s));
>> - if (d != null) {
>> - output += d;
>> - }
>> - }
>> - tokenStream.end();
>> - }
>> -
>> - double score = 1 - Math.exp(-1 * Math.abs(threshold -
>> output.doubleValue()) / threshold);
>> - return new ClassificationResult<>(output >= threshold, score);
>> - }
>> -
>> - /**
>> - * {@inheritDoc}
>> - */
>> - @Override
>> - public void train(LeafReader leafReader, String textFieldName,
>> - String classFieldName, Analyzer analyzer) throws
>> IOException {
>> - train(leafReader, textFieldName, classFieldName, analyzer, null);
>> - }
>> -
>> - /**
>> - * {@inheritDoc}
>> - */
>> - @Override
>> - public void train(LeafReader leafReader, String textFieldName,
>> - String classFieldName, Analyzer analyzer, Query
>> query) throws IOException {
>> + public BooleanPerceptronClassifier(LeafReader leafReader, String
>> textFieldName, String classFieldName, Analyzer analyzer,
>> + Query query, Integer batchSize,
>> Double threshold) throws IOException {
>> this.textTerms = MultiFields.getTerms(leafReader, textFieldName);
>>
>> if (textTerms == null) {
>> @@ -144,9 +82,11 @@ public class BooleanPerceptronClassifier
>> this.threshold = (double) sumDocFreq / 2d;
>> } else {
>> throw new IOException(
>> - "threshold cannot be assigned since term vectors for field "
>> - + textFieldName + " do not exist");
>> + "threshold cannot be assigned since term vectors for
>> field "
>> + + textFieldName + " do not exist");
>> }
>> + } else {
>> + this.threshold = threshold;
>> }
>>
>> // TODO : remove this map as soon as we have a writable FST
>> @@ -170,7 +110,7 @@ public class BooleanPerceptronClassifier
>> }
>> // run the search and use stored field values
>> for (ScoreDoc scoreDoc : indexSearcher.search(q,
>> - Integer.MAX_VALUE).scoreDocs) {
>> + Integer.MAX_VALUE).scoreDocs) {
>> StoredDocument doc = indexSearcher.doc(scoreDoc.doc);
>>
>> StorableField textField = doc.getField(textFieldName);
>> @@ -187,7 +127,7 @@ public class BooleanPerceptronClassifier
>> long modifier = correctClass.compareTo(assignedClass);
>> if (modifier != 0) {
>> updateWeights(leafReader, scoreDoc.doc, assignedClass,
>> - weights, modifier, batchCount % batchSize == 0);
>> + weights, modifier, batchCount % batchSize == 0);
>> }
>> batchCount++;
>> }
>> @@ -195,11 +135,6 @@ public class BooleanPerceptronClassifier
>> weights.clear(); // free memory while waiting for GC
>> }
>>
>> - @Override
>> - public void train(LeafReader leafReader, String[] textFieldNames,
>> String classFieldName, Analyzer analyzer, Query query) throws IOException {
>> - throw new IOException("training with multiple fields not supported
>> by boolean perceptron classifier");
>> - }
>> -
>> private void updateWeights(LeafReader leafReader,
>> int docId, Boolean assignedClass,
>> SortedMap<String, Double> weights,
>> double modifier, boolean updateFST) throws
>> IOException {
>> @@ -210,7 +145,7 @@ public class BooleanPerceptronClassifier
>>
>> if (terms == null) {
>> throw new IOException("term vectors must be stored for field "
>> - + textFieldName);
>> + + textFieldName);
>> }
>>
>> TermsEnum termsEnum = terms.iterator();
>> @@ -240,17 +175,46 @@ public class BooleanPerceptronClassifier
>> for (Map.Entry<String, Double> entry : weights.entrySet()) {
>> scratchBytes.copyChars(entry.getKey());
>> fstBuilder.add(Util.toIntsRef(scratchBytes.get(), scratchInts),
>> entry
>> - .getValue().longValue());
>> + .getValue().longValue());
>> }
>> fst = fstBuilder.finish();
>> }
>>
>> +
>> + /**
>> + * {@inheritDoc}
>> + */
>> + @Override
>> + public ClassificationResult<Boolean> assignClass(String text)
>> + throws IOException {
>> + if (textTerms == null) {
>> + throw new IOException("You must first call Classifier#train");
>> + }
>> + Long output = 0l;
>> + try (TokenStream tokenStream = analyzer.tokenStream(textFieldName,
>> text)) {
>> + CharTermAttribute charTermAttribute = tokenStream
>> + .addAttribute(CharTermAttribute.class);
>> + tokenStream.reset();
>> + while (tokenStream.incrementToken()) {
>> + String s = charTermAttribute.toString();
>> + Long d = Util.get(fst, new BytesRef(s));
>> + if (d != null) {
>> + output += d;
>> + }
>> + }
>> + tokenStream.end();
>> + }
>> +
>> + double score = 1 - Math.exp(-1 * Math.abs(threshold -
>> output.doubleValue()) / threshold);
>> + return new ClassificationResult<>(output >= threshold, score);
>> + }
>> +
>> /**
>> * {@inheritDoc}
>> */
>> @Override
>> public List<ClassificationResult<Boolean>> getClasses(String text)
>> - throws IOException {
>> + throws IOException {
>> throw new RuntimeException("not implemented");
>> }
>>
>> @@ -259,7 +223,7 @@ public class BooleanPerceptronClassifier
>> */
>> @Override
>> public List<ClassificationResult<Boolean>> getClasses(String text, int
>> max)
>> - throws IOException {
>> + throws IOException {
>> throw new RuntimeException("not implemented");
>> }
>>
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
>> Thu Apr 30 14:12:03 2015
>> @@ -49,50 +49,30 @@ import org.apache.lucene.util.BytesRef;
>> */
>> public class CachingNaiveBayesClassifier extends
>> SimpleNaiveBayesClassifier {
>> //for caching classes this will be the classification class list
>> - private ArrayList<BytesRef> cclasses = new ArrayList<>();
>> + private final ArrayList<BytesRef> cclasses = new ArrayList<>();
>> // it's a term-inmap style map, where the inmap contains class-hit
>> pairs to the
>> // upper term
>> - private Map<String, Map<BytesRef, Integer>> termCClassHitCache = new
>> HashMap<>();
>> + private final Map<String, Map<BytesRef, Integer>> termCClassHitCache =
>> new HashMap<>();
>> // the term frequency in classes
>> - private Map<BytesRef, Double> classTermFreq = new HashMap<>();
>> + private final Map<BytesRef, Double> classTermFreq = new HashMap<>();
>> private boolean justCachedTerms;
>> private int docsWithClassSize;
>>
>> /**
>> - * Creates a new NaiveBayes classifier with inside caching. Note that
>> you must
>> - * call {@link #train(org.apache.lucene.index.LeafReader, String,
>> String, Analyzer) train()} before
>> - * you can classify any documents. If you want less memory usage you
>> could
>> + * Creates a new NaiveBayes classifier with inside caching. If you
>> want less memory usage you could
>> * call {@link #reInitCache(int, boolean) reInitCache()}.
>> */
>> - public CachingNaiveBayesClassifier() {
>> - }
>> -
>> - /**
>> - * {@inheritDoc}
>> - */
>> - @Override
>> - public void train(LeafReader leafReader, String textFieldName, String
>> classFieldName, Analyzer analyzer) throws IOException {
>> - train(leafReader, textFieldName, classFieldName, analyzer, null);
>> - }
>> -
>> - /**
>> - * {@inheritDoc}
>> - */
>> - @Override
>> - public void train(LeafReader leafReader, String textFieldName, String
>> classFieldName, Analyzer analyzer, Query query) throws IOException {
>> - train(leafReader, new String[]{textFieldName}, classFieldName,
>> analyzer, query);
>> - }
>> -
>> - /**
>> - * {@inheritDoc}
>> - */
>> - @Override
>> - public void train(LeafReader leafReader, String[] textFieldNames,
>> String classFieldName, Analyzer analyzer, Query query) throws IOException {
>> - super.train(leafReader, textFieldNames, classFieldName, analyzer,
>> query);
>> + public CachingNaiveBayesClassifier(LeafReader leafReader, Analyzer
>> analyzer, Query query, String classFieldName, String... textFieldNames) {
>> + super(leafReader, analyzer, query, classFieldName, textFieldNames);
>> // building the cache
>> - reInitCache(0, true);
>> + try {
>> + reInitCache(0, true);
>> + } catch (IOException e) {
>> + throw new RuntimeException(e);
>> + }
>> }
>>
>> +
>> private List<ClassificationResult<BytesRef>>
>> assignClassNormalizedList(String inputDocument) throws IOException {
>> if (leafReader == null) {
>> throw new IOException("You must first call Classifier#train");
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java
>> Thu Apr 30 14:12:03 2015
>> @@ -18,17 +18,19 @@ package org.apache.lucene.classification
>>
>> /**
>> * The result of a call to {@link Classifier#assignClass(String)}
>> holding an assigned class of type <code>T</code> and a score.
>> + *
>> * @lucene.experimental
>> */
>> -public class ClassificationResult<T> implements
>> Comparable<ClassificationResult<T>>{
>> +public class ClassificationResult<T> implements
>> Comparable<ClassificationResult<T>> {
>>
>> private final T assignedClass;
>> private double score;
>>
>> /**
>> * Constructor
>> + *
>> * @param assignedClass the class <code>T</code> assigned by a {@link
>> Classifier}
>> - * @param score the score for the assignedClass as a
>> <code>double</code>
>> + * @param score the score for the assignedClass as a
>> <code>double</code>
>> */
>> public ClassificationResult(T assignedClass, double score) {
>> this.assignedClass = assignedClass;
>> @@ -37,6 +39,7 @@ public class ClassificationResult<T> imp
>>
>> /**
>> * retrieve the result class
>> + *
>> * @return a <code>T</code> representing an assigned class
>> */
>> public T getAssignedClass() {
>> @@ -45,14 +48,16 @@ public class ClassificationResult<T> imp
>>
>> /**
>> * retrieve the result score
>> + *
>> * @return a <code>double</code> representing a result score
>> */
>> public double getScore() {
>> return score;
>> }
>> -
>> +
>> /**
>> * set the score value
>> + *
>> * @param score the score for the assignedClass as a
>> <code>double</code>
>> */
>> public void setScore(double score) {
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
>> Thu Apr 30 14:12:03 2015
>> @@ -22,7 +22,6 @@ import java.util.List;
>> import org.apache.lucene.analysis.Analyzer;
>> import org.apache.lucene.index.LeafReader;
>> import org.apache.lucene.search.Query;
>> -import org.apache.lucene.util.BytesRef;
>>
>> /**
>> * A classifier, see <code>
>> http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>, which
>> assign classes of type
>> @@ -39,7 +38,7 @@ public interface Classifier<T> {
>> * @return a {@link ClassificationResult} holding assigned class of
>> type <code>T</code> and score
>> * @throws IOException If there is a low-level I/O error.
>> */
>> - public ClassificationResult<T> assignClass(String text) throws
>> IOException;
>> + ClassificationResult<T> assignClass(String text) throws IOException;
>>
>> /**
>> * Get all the classes (sorted by score, descending) assigned to the
>> given text String.
>> @@ -48,7 +47,7 @@ public interface Classifier<T> {
>> * @return the whole list of {@link ClassificationResult}, the classes
>> and scores. Returns <code>null</code> if the classifier can't make lists.
>> * @throws IOException If there is a low-level I/O error.
>> */
>> - public List<ClassificationResult<T>> getClasses(String text) throws
>> IOException;
>> + List<ClassificationResult<T>> getClasses(String text) throws
>> IOException;
>>
>> /**
>> * Get the first <code>max</code> classes (sorted by score,
>> descending) assigned to the given text String.
>> @@ -58,44 +57,6 @@ public interface Classifier<T> {
>> * @return the whole list of {@link ClassificationResult}, the classes
>> and scores. Cut for "max" number of elements. Returns <code>null</code> if
>> the classifier can't make lists.
>> * @throws IOException If there is a low-level I/O error.
>> */
>> - public List<ClassificationResult<T>> getClasses(String text, int max)
>> throws IOException;
>> -
>> - /**
>> - * Train the classifier using the underlying Lucene index
>> - *
>> - * @param leafReader the reader to use to access the Lucene index
>> - * @param textFieldName the name of the field used to compare
>> documents
>> - * @param classFieldName the name of the field containing the class
>> assigned to documents
>> - * @param analyzer the analyzer used to tokenize / filter the
>> unseen text
>> - * @throws IOException If there is a low-level I/O error.
>> - */
>> - public void train(LeafReader leafReader, String textFieldName, String
>> classFieldName, Analyzer analyzer)
>> - throws IOException;
>> -
>> - /**
>> - * Train the classifier using the underlying Lucene index
>> - *
>> - * @param leafReader the reader to use to access the Lucene index
>> - * @param textFieldName the name of the field used to compare
>> documents
>> - * @param classFieldName the name of the field containing the class
>> assigned to documents
>> - * @param analyzer the analyzer used to tokenize / filter the
>> unseen text
>> - * @param query the query to filter which documents use for
>> training
>> - * @throws IOException If there is a low-level I/O error.
>> - */
>> - public void train(LeafReader leafReader, String textFieldName, String
>> classFieldName, Analyzer analyzer, Query query)
>> - throws IOException;
>> -
>> - /**
>> - * Train the classifier using the underlying Lucene index
>> - *
>> - * @param leafReader the reader to use to access the Lucene index
>> - * @param textFieldNames the names of the fields to be used to compare
>> documents
>> - * @param classFieldName the name of the field containing the class
>> assigned to documents
>> - * @param analyzer the analyzer used to tokenize / filter the
>> unseen text
>> - * @param query the query to filter which documents use for
>> training
>> - * @throws IOException If there is a low-level I/O error.
>> - */
>> - public void train(LeafReader leafReader, String[] textFieldNames,
>> String classFieldName, Analyzer analyzer, Query query)
>> - throws IOException;
>> + List<ClassificationResult<T>> getClasses(String text, int max) throws
>> IOException;
>>
>> }
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
>> Thu Apr 30 14:12:03 2015
>> @@ -26,6 +26,7 @@ import java.util.Map;
>>
>> import org.apache.lucene.analysis.Analyzer;
>> import org.apache.lucene.index.LeafReader;
>> +import org.apache.lucene.index.StorableField;
>> import org.apache.lucene.index.Term;
>> import org.apache.lucene.queries.mlt.MoreLikeThis;
>> import org.apache.lucene.search.BooleanClause;
>> @@ -45,37 +46,31 @@ import org.apache.lucene.util.BytesRef;
>> */
>> public class KNearestNeighborClassifier implements Classifier<BytesRef> {
>>
>> - private MoreLikeThis mlt;
>> - private String[] textFieldNames;
>> - private String classFieldName;
>> - private IndexSearcher indexSearcher;
>> + private final MoreLikeThis mlt;
>> + private final String[] textFieldNames;
>> + private final String classFieldName;
>> + private final IndexSearcher indexSearcher;
>> private final int k;
>> - private Query query;
>> + private final Query query;
>>
>> - private int minDocsFreq;
>> - private int minTermFreq;
>> -
>> - /**
>> - * Create a {@link Classifier} using kNN algorithm
>> - *
>> - * @param k the number of neighbors to analyze as an <code>int</code>
>> - */
>> - public KNearestNeighborClassifier(int k) {
>> + public KNearestNeighborClassifier(LeafReader leafReader, Analyzer
>> analyzer, Query query, int k, int minDocsFreq,
>> + int minTermFreq, String
>> classFieldName, String... textFieldNames) {
>> + this.textFieldNames = textFieldNames;
>> + this.classFieldName = classFieldName;
>> + this.mlt = new MoreLikeThis(leafReader);
>> + this.mlt.setAnalyzer(analyzer);
>> + this.mlt.setFieldNames(textFieldNames);
>> + this.indexSearcher = new IndexSearcher(leafReader);
>> + if (minDocsFreq > 0) {
>> + mlt.setMinDocFreq(minDocsFreq);
>> + }
>> + if (minTermFreq > 0) {
>> + mlt.setMinTermFreq(minTermFreq);
>> + }
>> + this.query = query;
>> this.k = k;
>> }
>>
>> - /**
>> - * Create a {@link Classifier} using kNN algorithm
>> - *
>> - * @param k the number of neighbors to analyze as an
>> <code>int</code>
>> - * @param minDocsFreq the minimum number of docs frequency for MLT to
>> be set with {@link MoreLikeThis#setMinDocFreq(int)}
>> - * @param minTermFreq the minimum number of term frequency for MLT to
>> be set with {@link MoreLikeThis#setMinTermFreq(int)}
>> - */
>> - public KNearestNeighborClassifier(int k, int minDocsFreq, int
>> minTermFreq) {
>> - this.k = k;
>> - this.minDocsFreq = minDocsFreq;
>> - this.minTermFreq = minTermFreq;
>> - }
>>
>> /**
>> * {@inheritDoc}
>> @@ -136,12 +131,15 @@ public class KNearestNeighborClassifier
>> private List<ClassificationResult<BytesRef>>
>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
>> Map<BytesRef, Integer> classCounts = new HashMap<>();
>> for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
>> - BytesRef cl = new
>> BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue());
>> - Integer count = classCounts.get(cl);
>> - if (count != null) {
>> - classCounts.put(cl, count + 1);
>> - } else {
>> - classCounts.put(cl, 1);
>> + StorableField storableField =
>> indexSearcher.doc(scoreDoc.doc).getField(classFieldName);
>> + if (storableField != null) {
>> + BytesRef cl = new BytesRef(storableField.stringValue());
>> + Integer count = classCounts.get(cl);
>> + if (count != null) {
>> + classCounts.put(cl, count + 1);
>> + } else {
>> + classCounts.put(cl, 1);
>> + }
>> }
>> }
>> List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
>> @@ -161,39 +159,4 @@ public class KNearestNeighborClassifier
>> return returnList;
>> }
>>
>> - /**
>> - * {@inheritDoc}
>> - */
>> - @Override
>> - public void train(LeafReader leafReader, String textFieldName, String
>> classFieldName, Analyzer analyzer) throws IOException {
>> - train(leafReader, textFieldName, classFieldName, analyzer, null);
>> - }
>> -
>> - /**
>> - * {@inheritDoc}
>> - */
>> - @Override
>> - public void train(LeafReader leafReader, String textFieldName, String
>> classFieldName, Analyzer analyzer, Query query) throws IOException {
>> - train(leafReader, new String[]{textFieldName}, classFieldName,
>> analyzer, query);
>> - }
>> -
>> - /**
>> - * {@inheritDoc}
>> - */
>> - @Override
>> - public void train(LeafReader leafReader, String[] textFieldNames,
>> String classFieldName, Analyzer analyzer, Query query) throws IOException {
>> - this.textFieldNames = textFieldNames;
>> - this.classFieldName = classFieldName;
>> - mlt = new MoreLikeThis(leafReader);
>> - mlt.setAnalyzer(analyzer);
>> - mlt.setFieldNames(textFieldNames);
>> - indexSearcher = new IndexSearcher(leafReader);
>> - if (minDocsFreq > 0) {
>> - mlt.setMinDocFreq(minDocsFreq);
>> - }
>> - if (minTermFreq > 0) {
>> - mlt.setMinTermFreq(minTermFreq);
>> - }
>> - this.query = query;
>> - }
>> }
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
>> Thu Apr 30 14:12:03 2015
>> @@ -51,64 +51,38 @@ public class SimpleNaiveBayesClassifier
>> * {@link org.apache.lucene.index.LeafReader} used to access the
>> {@link org.apache.lucene.classification.Classifier}'s
>> * index
>> */
>> - protected LeafReader leafReader;
>> + protected final LeafReader leafReader;
>>
>> /**
>> * names of the fields to be used as input text
>> */
>> - protected String[] textFieldNames;
>> + protected final String[] textFieldNames;
>>
>> /**
>> * name of the field to be used as a class / category output
>> */
>> - protected String classFieldName;
>> + protected final String classFieldName;
>>
>> /**
>> * {@link org.apache.lucene.analysis.Analyzer} to be used for
>> tokenizing unseen input text
>> */
>> - protected Analyzer analyzer;
>> + protected final Analyzer analyzer;
>>
>> /**
>> * {@link org.apache.lucene.search.IndexSearcher} to run searches on
>> the index for retrieving frequencies
>> */
>> - protected IndexSearcher indexSearcher;
>> + protected final IndexSearcher indexSearcher;
>>
>> /**
>> * {@link org.apache.lucene.search.Query} used to eventually filter
>> the document set to be used to classify
>> */
>> - protected Query query;
>> + protected final Query query;
>>
>> /**
>> * Creates a new NaiveBayes classifier.
>> - * Note that you must call {@link
>> #train(org.apache.lucene.index.LeafReader, String, String, Analyzer)
>> train()} before you can
>> * classify any documents.
>> */
>> - public SimpleNaiveBayesClassifier() {
>> - }
>> -
>> - /**
>> - * {@inheritDoc}
>> - */
>> - @Override
>> - public void train(LeafReader leafReader, String textFieldName, String
>> classFieldName, Analyzer analyzer) throws IOException {
>> - train(leafReader, textFieldName, classFieldName, analyzer, null);
>> - }
>> -
>> - /**
>> - * {@inheritDoc}
>> - */
>> - @Override
>> - public void train(LeafReader leafReader, String textFieldName, String
>> classFieldName, Analyzer analyzer, Query query)
>> - throws IOException {
>> - train(leafReader, new String[]{textFieldName}, classFieldName,
>> analyzer, query);
>> - }
>> -
>> - /**
>> - * {@inheritDoc}
>> - */
>> - @Override
>> - public void train(LeafReader leafReader, String[] textFieldNames,
>> String classFieldName, Analyzer analyzer, Query query)
>> - throws IOException {
>> + public SimpleNaiveBayesClassifier(LeafReader leafReader, Analyzer
>> analyzer, Query query, String classFieldName, String... textFieldNames) {
>> this.leafReader = leafReader;
>> this.indexSearcher = new IndexSearcher(this.leafReader);
>> this.textFieldNames = textFieldNames;
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java
>> Thu Apr 30 14:12:03 2015
>> @@ -18,7 +18,7 @@
>> /**
>> * Uses already seen data (the indexed documents) to classify new
>> documents.
>> * <p>
>> - * Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest
>> + * Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest
>> * Neighbor classifier and a Perceptron based classifier.
>> */
>> package org.apache.lucene.classification;
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java
>> Thu Apr 30 14:12:03 2015
>> @@ -33,7 +33,8 @@ public class DocToDoubleVectorUtils {
>>
>> /**
>> * create a sparse <code>Double</code> vector given doc and field term
>> vectors using local frequency of the terms in the doc
>> - * @param docTerms term vectors for a given document
>> + *
>> + * @param docTerms term vectors for a given document
>> * @param fieldTerms field term vectors
>> * @return a sparse vector of <code>Double</code>s as an array
>> * @throws IOException in case accessing the underlying index fails
>> @@ -54,8 +55,7 @@ public class DocToDoubleVectorUtils {
>> if (seekStatus.equals(TermsEnum.SeekStatus.FOUND)) {
>> long termFreqLocal = docTermsEnum.totalTermFreq(); // the
>> total number of occurrences of this term in the given document
>> freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
>> - }
>> - else {
>> + } else {
>> freqVector[i] = 0d;
>> }
>> i++;
>> @@ -66,6 +66,7 @@ public class DocToDoubleVectorUtils {
>>
>> /**
>> * create a dense <code>Double</code> vector given doc and field term
>> vectors using local frequency of the terms in the doc
>> + *
>> * @param docTerms term vectors for a given document
>> * @return a dense vector of <code>Double</code>s as an array
>> * @throws IOException in case accessing the underlying index fails
>> @@ -73,16 +74,16 @@ public class DocToDoubleVectorUtils {
>> public static Double[] toDenseLocalFreqDoubleArray(Terms docTerms)
>> throws IOException {
>> Double[] freqVector = null;
>> if (docTerms != null) {
>> - freqVector = new Double[(int) docTerms.size()];
>> - int i = 0;
>> - TermsEnum docTermsEnum = docTerms.iterator();
>> + freqVector = new Double[(int) docTerms.size()];
>> + int i = 0;
>> + TermsEnum docTermsEnum = docTerms.iterator();
>>
>> - while (docTermsEnum.next() != null) {
>> - long termFreqLocal = docTermsEnum.totalTermFreq(); // the total number of occurrences of this term in the given document
>> - freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
>> - i++;
>> - }
>> + while (docTermsEnum.next() != null) {
>> + long termFreqLocal = docTermsEnum.totalTermFreq(); // the total number of occurrences of this term in the given document
>> + freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
>> + i++;
>> + }
>> }
>> return freqVector;
>> -}
>> + }
>> }
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java
>> Thu Apr 30 14:12:03 2015
>> @@ -17,6 +17,8 @@
>> package org.apache.lucene.classification;
>>
>> import org.apache.lucene.analysis.MockAnalyzer;
>> +import org.apache.lucene.index.LeafReader;
>> +import org.apache.lucene.index.SlowCompositeReaderWrapper;
>> import org.apache.lucene.index.Term;
>> import org.apache.lucene.search.TermQuery;
>> import org.junit.Test;
>> @@ -28,22 +30,45 @@ public class BooleanPerceptronClassifier
>>
>> @Test
>> public void testBasicUsage() throws Exception {
>> - checkCorrectClassification(new BooleanPerceptronClassifier(), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName);
>> + LeafReader leafReader = null;
>> + try {
>> + MockAnalyzer analyzer = new MockAnalyzer(random());
>> + leafReader = populateSampleIndex(analyzer);
>> + checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName, analyzer, null, 1, null), TECHNOLOGY_INPUT, false);
>> + } finally {
>> + if (leafReader != null) {
>> + leafReader.close();
>> + }
>> + }
>> }
>>
>> @Test
>> public void testExplicitThreshold() throws Exception {
>> - checkCorrectClassification(new BooleanPerceptronClassifier(100d, 1),
>> TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName,
>> booleanFieldName);
>> + LeafReader leafReader = null;
>> + try {
>> + MockAnalyzer analyzer = new MockAnalyzer(random());
>> + leafReader = populateSampleIndex(analyzer);
>> + checkCorrectClassification(new
>> BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName,
>> analyzer, null, 1, 100d), TECHNOLOGY_INPUT, false);
>> + } finally {
>> + if (leafReader != null) {
>> + leafReader.close();
>> + }
>> + }
>> }
>>
>> @Test
>> public void testBasicUsageWithQuery() throws Exception {
>> - checkCorrectClassification(new BooleanPerceptronClassifier(),
>> TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName,
>> booleanFieldName, new TermQuery(new Term(textFieldName, "it")));
>> - }
>> -
>> - @Test
>> - public void testPerformance() throws Exception {
>> - checkPerformance(new BooleanPerceptronClassifier(), new
>> MockAnalyzer(random()), booleanFieldName);
>> + TermQuery query = new TermQuery(new Term(textFieldName, "it"));
>> + LeafReader leafReader = null;
>> + try {
>> + MockAnalyzer analyzer = new MockAnalyzer(random());
>> + leafReader = populateSampleIndex(analyzer);
>> + checkCorrectClassification(new
>> BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName,
>> analyzer, query, 1, null), TECHNOLOGY_INPUT, false);
>> + } finally {
>> + if (leafReader != null) {
>> + leafReader.close();
>> + }
>> + }
>> }
>>
>> }
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java
>> Thu Apr 30 14:12:03 2015
>> @@ -23,6 +23,8 @@ import org.apache.lucene.analysis.Tokeni
>> import org.apache.lucene.analysis.core.KeywordTokenizer;
>> import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
>> import org.apache.lucene.analysis.reverse.ReverseStringFilter;
>> +import org.apache.lucene.index.LeafReader;
>> +import org.apache.lucene.index.SlowCompositeReaderWrapper;
>> import org.apache.lucene.index.Term;
>> import org.apache.lucene.search.TermQuery;
>> import org.apache.lucene.util.BytesRef;
>> @@ -35,18 +37,46 @@ public class CachingNaiveBayesClassifier
>>
>> @Test
>> public void testBasicUsage() throws Exception {
>> - checkCorrectClassification(new CachingNaiveBayesClassifier(),
>> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()),
>> textFieldName, categoryFieldName);
>> - checkCorrectClassification(new CachingNaiveBayesClassifier(),
>> POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName,
>> categoryFieldName);
>> + LeafReader leafReader = null;
>> + try {
>> + MockAnalyzer analyzer = new MockAnalyzer(random());
>> + leafReader = populateSampleIndex(analyzer);
>> + checkCorrectClassification(new
>> CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName,
>> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
>> + checkCorrectClassification(new
>> CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName,
>> textFieldName), POLITICS_INPUT, POLITICS_RESULT);
>> + } finally {
>> + if (leafReader != null) {
>> + leafReader.close();
>> + }
>> + }
>> }
>>
>> @Test
>> public void testBasicUsageWithQuery() throws Exception {
>> - checkCorrectClassification(new CachingNaiveBayesClassifier(),
>> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()),
>> textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName,
>> "it")));
>> + LeafReader leafReader = null;
>> + try {
>> + MockAnalyzer analyzer = new MockAnalyzer(random());
>> + leafReader = populateSampleIndex(analyzer);
>> + TermQuery query = new TermQuery(new Term(textFieldName, "it"));
>> + checkCorrectClassification(new
>> CachingNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName,
>> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
>> + } finally {
>> + if (leafReader != null) {
>> + leafReader.close();
>> + }
>> + }
>> }
>>
>> @Test
>> public void testNGramUsage() throws Exception {
>> - checkCorrectClassification(new CachingNaiveBayesClassifier(),
>> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName,
>> categoryFieldName);
>> + LeafReader leafReader = null;
>> + try {
>> + NGramAnalyzer analyzer = new NGramAnalyzer();
>> + leafReader = populateSampleIndex(analyzer);
>> + checkCorrectClassification(new
>> CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName,
>> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
>> + } finally {
>> + if (leafReader != null) {
>> + leafReader.close();
>> + }
>> + }
>> }
>>
>> private class NGramAnalyzer extends Analyzer {
>> @@ -57,9 +87,4 @@ public class CachingNaiveBayesClassifier
>> }
>> }
>>
>> - @Test
>> - public void testPerformance() throws Exception {
>> - checkPerformance(new CachingNaiveBayesClassifier(), new
>> MockAnalyzer(random()), categoryFieldName);
>> - }
>> -
>> }
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
>> Thu Apr 30 14:12:03 2015
>> @@ -41,14 +41,14 @@ import org.junit.Before;
>> */
>> public abstract class ClassificationTestBase<T> extends LuceneTestCase {
>> public final static String POLITICS_INPUT = "Here are some interesting
>> questions and answers about Mitt Romney.. " +
>> - "If you don't know the answer to the question about Mitt Romney,
>> then simply click on the answer below the question section.";
>> + "If you don't know the answer to the question about Mitt
>> Romney, then simply click on the answer below the question section.";
>> public static final BytesRef POLITICS_RESULT = new
>> BytesRef("politics");
>>
>> public static final String TECHNOLOGY_INPUT = "Much is made of what
>> the likes of Facebook, Google and Apple know about users." +
>> - " Truth is, Amazon may know more.";
>> + " Truth is, Amazon may know more.";
>> public static final BytesRef TECHNOLOGY_RESULT = new
>> BytesRef("technology");
>>
>> - private RandomIndexWriter indexWriter;
>> + protected RandomIndexWriter indexWriter;
>> private Directory dir;
>> private FieldType ft;
>>
>> @@ -79,53 +79,34 @@ public abstract class ClassificationTest
>> dir.close();
>> }
>>
>> - protected void checkCorrectClassification(Classifier<T> classifier,
>> String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName,
>> String classFieldName) throws Exception {
>> - checkCorrectClassification(classifier, inputDoc, expectedResult,
>> analyzer, textFieldName, classFieldName, null);
>> + protected void checkCorrectClassification(Classifier<T> classifier,
>> String inputDoc, T expectedResult) throws Exception {
>> + ClassificationResult<T> classificationResult =
>> classifier.assignClass(inputDoc);
>> + assertNotNull(classificationResult.getAssignedClass());
>> + assertEquals("got an assigned class of " +
>> classificationResult.getAssignedClass(), expectedResult,
>> classificationResult.getAssignedClass());
>> + double score = classificationResult.getScore();
>> + assertTrue("score should be between 0 and 1, got:" + score, score <=
>> 1 && score >= 0);
>> }
>>
>> - protected void checkCorrectClassification(Classifier<T> classifier,
>> String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName,
>> String classFieldName, Query query) throws Exception {
>> - LeafReader leafReader = null;
>> - try {
>> - populateSampleIndex(analyzer);
>> - leafReader =
>> SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
>> - classifier.train(leafReader, textFieldName, classFieldName,
>> analyzer, query);
>> - ClassificationResult<T> classificationResult =
>> classifier.assignClass(inputDoc);
>> - assertNotNull(classificationResult.getAssignedClass());
>> - assertEquals("got an assigned class of " +
>> classificationResult.getAssignedClass(), expectedResult,
>> classificationResult.getAssignedClass());
>> - double score = classificationResult.getScore();
>> - assertTrue("score should be between 0 and 1, got:" + score, score
>> <= 1 && score >= 0);
>> - } finally {
>> - if (leafReader != null)
>> - leafReader.close();
>> - }
>> - }
>> protected void checkOnlineClassification(Classifier<T> classifier,
>> String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName,
>> String classFieldName) throws Exception {
>> checkOnlineClassification(classifier, inputDoc, expectedResult,
>> analyzer, textFieldName, classFieldName, null);
>> }
>>
>> protected void checkOnlineClassification(Classifier<T> classifier,
>> String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName,
>> String classFieldName, Query query) throws Exception {
>> - LeafReader leafReader = null;
>> - try {
>> - populateSampleIndex(analyzer);
>> - leafReader =
>> SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
>> - classifier.train(leafReader, textFieldName, classFieldName,
>> analyzer, query);
>> - ClassificationResult<T> classificationResult =
>> classifier.assignClass(inputDoc);
>> - assertNotNull(classificationResult.getAssignedClass());
>> - assertEquals("got an assigned class of " +
>> classificationResult.getAssignedClass(), expectedResult,
>> classificationResult.getAssignedClass());
>> - double score = classificationResult.getScore();
>> - assertTrue("score should be between 0 and 1, got: " + score, score
>> <= 1 && score >= 0);
>> - updateSampleIndex();
>> - ClassificationResult<T> secondClassificationResult =
>> classifier.assignClass(inputDoc);
>> - assertEquals(classificationResult.getAssignedClass(),
>> secondClassificationResult.getAssignedClass());
>> - assertEquals(Double.valueOf(score),
>> Double.valueOf(secondClassificationResult.getScore()));
>> -
>> - } finally {
>> - if (leafReader != null)
>> - leafReader.close();
>> - }
>> + populateSampleIndex(analyzer);
>> +
>> + ClassificationResult<T> classificationResult =
>> classifier.assignClass(inputDoc);
>> + assertNotNull(classificationResult.getAssignedClass());
>> + assertEquals("got an assigned class of " +
>> classificationResult.getAssignedClass(), expectedResult,
>> classificationResult.getAssignedClass());
>> + double score = classificationResult.getScore();
>> + assertTrue("score should be between 0 and 1, got: " + score, score
>> <= 1 && score >= 0);
>> + updateSampleIndex();
>> + ClassificationResult<T> secondClassificationResult =
>> classifier.assignClass(inputDoc);
>> + assertEquals(classificationResult.getAssignedClass(),
>> secondClassificationResult.getAssignedClass());
>> + assertEquals(Double.valueOf(score),
>> Double.valueOf(secondClassificationResult.getScore()));
>> +
>> }
>>
>> - private void populateSampleIndex(Analyzer analyzer) throws IOException {
>> + protected LeafReader populateSampleIndex(Analyzer analyzer) throws IOException {
>> indexWriter.close();
>> indexWriter = new RandomIndexWriter(random(), dir,
>> newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE));
>> indexWriter.commit();
>> @@ -134,8 +115,8 @@ public abstract class ClassificationTest
>>
>> Document doc = new Document();
>> text = "The traveling press secretary for Mitt Romney lost his cool
>> and cursed at reporters " +
>> - "who attempted to ask questions of the Republican presidential
>> candidate in a public plaza near the Tomb of " +
>> - "the Unknown Soldier in Warsaw Tuesday.";
>> + "who attempted to ask questions of the Republican
>> presidential candidate in a public plaza near the Tomb of " +
>> + "the Unknown Soldier in Warsaw Tuesday.";
>> doc.add(new Field(textFieldName, text, ft));
>> doc.add(new Field(categoryFieldName, "politics", ft));
>> doc.add(new Field(booleanFieldName, "true", ft));
>> @@ -144,7 +125,7 @@ public abstract class ClassificationTest
>>
>> doc = new Document();
>> text = "Mitt Romney seeks to assure Israel and Iran, as well as
>> Jewish voters in the United" +
>> - " States, that he will be tougher against Iran's nuclear
>> ambitions than President Barack Obama.";
>> + " States, that he will be tougher against Iran's nuclear
>> ambitions than President Barack Obama.";
>> doc.add(new Field(textFieldName, text, ft));
>> doc.add(new Field(categoryFieldName, "politics", ft));
>> doc.add(new Field(booleanFieldName, "true", ft));
>> @@ -152,8 +133,8 @@ public abstract class ClassificationTest
>>
>> doc = new Document();
>> text = "And there's a threshold question that he has to answer for
>> the American people and " +
>> - "that's whether he is prepared to be commander-in-chief,\" she
>> continued. \"As we look to the past events, we " +
>> - "know that this raises some questions about his preparedness and
>> we'll see how the rest of his trip goes.\"";
>> + "that's whether he is prepared to be commander-in-chief,\"
>> she continued. \"As we look to the past events, we " +
>> + "know that this raises some questions about his preparedness
>> and we'll see how the rest of his trip goes.\"";
>> doc.add(new Field(textFieldName, text, ft));
>> doc.add(new Field(categoryFieldName, "politics", ft));
>> doc.add(new Field(booleanFieldName, "true", ft));
>> @@ -161,8 +142,8 @@ public abstract class ClassificationTest
>>
>> doc = new Document();
>> text = "Still, when it comes to gun policy, many congressional
>> Democrats have \"decided to " +
>> - "keep quiet and not go there,\" said Alan Lizotte, dean and
>> professor at the State University of New York at " +
>> - "Albany's School of Criminal Justice.";
>> + "keep quiet and not go there,\" said Alan Lizotte, dean and
>> professor at the State University of New York at " +
>> + "Albany's School of Criminal Justice.";
>> doc.add(new Field(textFieldName, text, ft));
>> doc.add(new Field(categoryFieldName, "politics", ft));
>> doc.add(new Field(booleanFieldName, "true", ft));
>> @@ -170,8 +151,8 @@ public abstract class ClassificationTest
>>
>> doc = new Document();
>> text = "Standing amongst the thousands of people at the state
>> Capitol, Jorstad, director of " +
>> - "technology at the University of Wisconsin-La Crosse, documented
>> the historic moment and shared it with the " +
>> - "world through the Internet.";
>> + "technology at the University of Wisconsin-La Crosse,
>> documented the historic moment and shared it with the " +
>> + "world through the Internet.";
>> doc.add(new Field(textFieldName, text, ft));
>> doc.add(new Field(categoryFieldName, "technology", ft));
>> doc.add(new Field(booleanFieldName, "false", ft));
>> @@ -179,7 +160,7 @@ public abstract class ClassificationTest
>>
>> doc = new Document();
>> text = "So, about all those experts and analysts who've spent the
>> past year or so saying " +
>> - "Facebook was going to make a phone. A new expert has stepped
>> forward to say it's not going to happen.";
>> + "Facebook was going to make a phone. A new expert has
>> stepped forward to say it's not going to happen.";
>> doc.add(new Field(textFieldName, text, ft));
>> doc.add(new Field(categoryFieldName, "technology", ft));
>> doc.add(new Field(booleanFieldName, "false", ft));
>> @@ -187,8 +168,8 @@ public abstract class ClassificationTest
>>
>> doc = new Document();
>> text = "More than 400 million people trust Google with their e-mail,
>> and 50 million store files" +
>> - " in the cloud using the Dropbox service. People manage their
>> bank accounts, pay bills, trade stocks and " +
>> - "generally transfer or store huge volumes of personal data
>> online.";
>> + " in the cloud using the Dropbox service. People manage
>> their bank accounts, pay bills, trade stocks and " +
>> + "generally transfer or store huge volumes of personal data
>> online.";
>> doc.add(new Field(textFieldName, text, ft));
>> doc.add(new Field(categoryFieldName, "technology", ft));
>> doc.add(new Field(booleanFieldName, "false", ft));
>> @@ -200,22 +181,15 @@ public abstract class ClassificationTest
>> indexWriter.addDocument(doc);
>>
>> indexWriter.commit();
>> + return SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
>> }
>>
>> protected void checkPerformance(Classifier<T> classifier, Analyzer
>> analyzer, String classFieldName) throws Exception {
>> - LeafReader leafReader = null;
>> long trainStart = System.currentTimeMillis();
>> - try {
>> - populatePerformanceIndex(analyzer);
>> - leafReader =
>> SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
>> - classifier.train(leafReader, textFieldName, classFieldName,
>> analyzer);
>> - long trainEnd = System.currentTimeMillis();
>> - long trainTime = trainEnd - trainStart;
>> - assertTrue("training took more than 2 mins : " + trainTime / 1000
>> + "s", trainTime < 120000);
>> - } finally {
>> - if (leafReader != null)
>> - leafReader.close();
>> - }
>> + populatePerformanceIndex(analyzer);
>> + long trainEnd = System.currentTimeMillis();
>> + long trainTime = trainEnd - trainStart;
>> + assertTrue("training took more than 2 mins : " + trainTime / 1000 +
>> "s", trainTime < 120000);
>> }
>>
>> private void populatePerformanceIndex(Analyzer analyzer) throws
>> IOException {
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
>> Thu Apr 30 14:12:03 2015
>> @@ -17,6 +17,8 @@
>> package org.apache.lucene.classification;
>>
>> import org.apache.lucene.analysis.MockAnalyzer;
>> +import org.apache.lucene.index.LeafReader;
>> +import org.apache.lucene.index.SlowCompositeReaderWrapper;
>> import org.apache.lucene.index.Term;
>> import org.apache.lucene.search.TermQuery;
>> import org.apache.lucene.util.BytesRef;
>> @@ -29,20 +31,32 @@ public class KNearestNeighborClassifierT
>>
>> @Test
>> public void testBasicUsage() throws Exception {
>> - // usage with default MLT min docs / term freq
>> - checkCorrectClassification(new KNearestNeighborClassifier(3),
>> POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName,
>> categoryFieldName);
>> - // usage without custom min docs / term freq for MLT
>> - checkCorrectClassification(new KNearestNeighborClassifier(3, 2, 1),
>> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()),
>> textFieldName, categoryFieldName);
>> + LeafReader leafReader = null;
>> + try {
>> + MockAnalyzer analyzer = new MockAnalyzer(random());
>> + leafReader = populateSampleIndex(analyzer);
>> + checkCorrectClassification(new
>> KNearestNeighborClassifier(leafReader, analyzer, null, 1, 0, 0,
>> categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
>> + checkCorrectClassification(new
>> KNearestNeighborClassifier(leafReader, analyzer, null, 3, 2, 1,
>> categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
>> + } finally {
>> + if (leafReader != null) {
>> + leafReader.close();
>> + }
>> + }
>> }
>>
>> @Test
>> public void testBasicUsageWithQuery() throws Exception {
>> - checkCorrectClassification(new KNearestNeighborClassifier(1),
>> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()),
>> textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName,
>> "it")));
>> - }
>> -
>> - @Test
>> - public void testPerformance() throws Exception {
>> - checkPerformance(new KNearestNeighborClassifier(100), new
>> MockAnalyzer(random()), categoryFieldName);
>> + LeafReader leafReader = null;
>> + try {
>> + MockAnalyzer analyzer = new MockAnalyzer(random());
>> + leafReader = populateSampleIndex(analyzer);
>> + TermQuery query = new TermQuery(new Term(textFieldName, "it"));
>> + checkCorrectClassification(new
>> KNearestNeighborClassifier(leafReader, analyzer, query, 1, 0, 0,
>> categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
>> + } finally {
>> + if (leafReader != null) {
>> + leafReader.close();
>> + }
>> + }
>> }
>>
>> }
>>
>> Modified:
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java
>> URL:
>> http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff
>>
>> ==============================================================================
>> ---
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java
>> (original)
>> +++
>> lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java
>> Thu Apr 30 14:12:03 2015
>> @@ -22,14 +22,13 @@ import org.apache.lucene.analysis.Tokeni
>> import org.apache.lucene.analysis.core.KeywordTokenizer;
>> import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
>> import org.apache.lucene.analysis.reverse.ReverseStringFilter;
>> +import org.apache.lucene.index.LeafReader;
>> +import org.apache.lucene.index.SlowCompositeReaderWrapper;
>> import org.apache.lucene.index.Term;
>> import org.apache.lucene.search.TermQuery;
>> import org.apache.lucene.util.BytesRef;
>> -import org.apache.lucene.util.LuceneTestCase;
>> import org.junit.Test;
>>
>> -import java.io.Reader;
>> -
>> /**
>> * Testcase for {@link SimpleNaiveBayesClassifier}
>> */
>> @@ -37,18 +36,46 @@ public class SimpleNaiveBayesClassifierT
>>
>> @Test
>> public void testBasicUsage() throws Exception {
>> - checkCorrectClassification(new SimpleNaiveBayesClassifier(),
>> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()),
>> textFieldName, categoryFieldName);
>> - checkCorrectClassification(new SimpleNaiveBayesClassifier(),
>> POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName,
>> categoryFieldName);
>> + LeafReader leafReader = null;
>> + try {
>> + MockAnalyzer analyzer = new MockAnalyzer(random());
>> + leafReader = populateSampleIndex(analyzer);
>> + checkCorrectClassification(new
>> SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName,
>> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
>> + checkCorrectClassification(new
>> SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName,
>> textFieldName), POLITICS_INPUT, POLITICS_RESULT);
>> + } finally {
>> + if (leafReader != null) {
>> + leafReader.close();
>> + }
>> + }
>> }
>>
>> @Test
>> public void testBasicUsageWithQuery() throws Exception {
>> - checkCorrectClassification(new SimpleNaiveBayesClassifier(),
>> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()),
>> textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName,
>> "it")));
>> + LeafReader leafReader = null;
>> + try {
>> + MockAnalyzer analyzer = new MockAnalyzer(random());
>> + leafReader = populateSampleIndex(analyzer);
>> + TermQuery query = new TermQuery(new Term(textFieldName, "it"));
>> + checkCorrectClassification(new
>> SimpleNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName,
>> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
>> + } finally {
>> + if (leafReader != null) {
>> + leafReader.close();
>> + }
>> + }
>> }
>>
>> @Test
>> public void testNGramUsage() throws Exception {
>> - checkCorrectClassification(new SimpleNaiveBayesClassifier(),
>> TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName,
>> categoryFieldName);
>> + LeafReader leafReader = null;
>> + try {
>> + Analyzer analyzer = new NGramAnalyzer();
>> + leafReader = populateSampleIndex(analyzer);
>> + checkCorrectClassification(new
>> CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName,
>> textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
>> + } finally {
>> + if (leafReader != null) {
>> + leafReader.close();
>> + }
>> + }
>> }
>>
>> private class NGramAnalyzer extends Analyzer {
>> @@ -59,9 +86,4 @@ public class SimpleNaiveBayesClassifierT
>> }
>> }
>>
>> - @Test
>> - public void testPerformance() throws Exception {
>> - checkPerformance(new SimpleNaiveBayesClassifier(), new
>> MockAnalyzer(random()), categoryFieldName);
>> - }
>> -
>> }
>>
>>
>>
>