You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by to...@apache.org on 2015/10/20 09:36:41 UTC
svn commit: r1709522 - in /lucene/dev/trunk/lucene/classification/src:
java/org/apache/lucene/classification/
java/org/apache/lucene/classification/document/
test/org/apache/lucene/classification/
test/org/apache/lucene/classification/document/
Author: tommaso
Date: Tue Oct 20 07:36:41 2015
New Revision: 1709522
URL: http://svn.apache.org/viewvc?rev=1709522&view=rev
Log:
LUCENE-6631 - added document classification api and impls
Added:
lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/document/
lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/document/DocumentClassifier.java
lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifier.java
lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifier.java
lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/document/package-info.java
lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/document/
lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/document/DocumentClassificationTestBase.java
lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifierTest.java
lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifierTest.java
Modified:
lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java
lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java?rev=1709522&r1=1709521&r2=1709522&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java Tue Oct 20 07:36:41 2015
@@ -2,7 +2,6 @@ package org.apache.lucene.classification
import java.io.IOException;
import java.util.ArrayList;
-import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -81,38 +80,17 @@ public class CachingNaiveBayesClassifier
protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
- String[] tokenizedDoc = tokenizeDoc(inputDocument);
+ String[] tokenizedText = tokenize(inputDocument);
- List<ClassificationResult<BytesRef>> dataList = calculateLogLikelihood(tokenizedDoc);
+ List<ClassificationResult<BytesRef>> assignedClasses = calculateLogLikelihood(tokenizedText);
// normalization
// The values transforms to a 0-1 range
- ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
- if (!dataList.isEmpty()) {
- Collections.sort(dataList);
- // this is a negative number closest to 0 = a
- double smax = dataList.get(0).getScore();
-
- double sumLog = 0;
- // log(sum(exp(x_n-a)))
- for (ClassificationResult<BytesRef> cr : dataList) {
- // getScore-smax <=0 (both negative, smax is the smallest abs()
- sumLog += Math.exp(cr.getScore() - smax);
- }
- // loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n)))
- double loga = smax;
- loga += Math.log(sumLog);
-
- // 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum))
- for (ClassificationResult<BytesRef> cr : dataList) {
- returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(cr.getScore() - loga)));
- }
- }
-
- return returnList;
+ ArrayList<ClassificationResult<BytesRef>> asignedClassesNorm = super.normClassificationResults(assignedClasses);
+ return asignedClassesNorm;
}
- private List<ClassificationResult<BytesRef>> calculateLogLikelihood(String[] tokenizedDoc) throws IOException {
+ private List<ClassificationResult<BytesRef>> calculateLogLikelihood(String[] tokenizedText) throws IOException {
// initialize the return List
ArrayList<ClassificationResult<BytesRef>> ret = new ArrayList<>();
for (BytesRef cclass : cclasses) {
@@ -120,7 +98,7 @@ public class CachingNaiveBayesClassifier
ret.add(cr);
}
// for each word
- for (String word : tokenizedDoc) {
+ for (String word : tokenizedText) {
// search with text:word for all class:c
Map<BytesRef, Integer> hitsInClasses = getWordFreqForClassess(word);
// for each class
Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java?rev=1709522&r1=1709521&r2=1709522&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java Tue Oct 20 07:36:41 2015
@@ -48,12 +48,12 @@ import org.apache.lucene.util.BytesRef;
*/
public class KNearestNeighborClassifier implements Classifier<BytesRef> {
- private final MoreLikeThis mlt;
- private final String[] textFieldNames;
- private final String classFieldName;
- private final IndexSearcher indexSearcher;
- private final int k;
- private final Query query;
+ protected final MoreLikeThis mlt;
+ protected final String[] textFieldNames;
+ protected final String classFieldName;
+ protected final IndexSearcher indexSearcher;
+ protected final int k;
+ protected final Query query;
/**
* Creates a {@link KNearestNeighborClassifier}.
@@ -159,7 +159,7 @@ public class KNearestNeighborClassifier
}
//ranking of classes must be taken in consideration
- private List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
+ protected List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
Map<BytesRef, Integer> classCounts = new HashMap<>();
Map<BytesRef, Double> classBoosts = new HashMap<>(); // this is a boost based on class ranking positions in topDocs
float maxScore = topDocs.getMaxScore();
Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java?rev=1709522&r1=1709521&r2=1709522&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java Tue Oct 20 07:36:41 2015
@@ -85,8 +85,9 @@ public class SimpleNaiveBayesClassifier
* @param analyzer an {@link Analyzer} used to analyze unseen text
* @param query a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
* if all the indexed docs should be used
- * @param classFieldName the name of the field used as the output for the classifier
- * @param textFieldNames the name of the fields used as the inputs for the classifier
+ * @param classFieldName the name of the field used as the output for the classifier. NOTE: it must not be heavily analyzed,
+ * as the returned class will be a token indexed for this field
+ * @param textFieldNames the name of the fields used as the inputs for the classifier, NO boosting supported per field
*/
public SimpleNaiveBayesClassifier(LeafReader leafReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
this.leafReader = leafReader;
@@ -102,16 +103,16 @@ public class SimpleNaiveBayesClassifier
*/
@Override
public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException {
- List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(inputDocument);
- ClassificationResult<BytesRef> retval = null;
+ List<ClassificationResult<BytesRef>> assignedClasses = assignClassNormalizedList(inputDocument);
+ ClassificationResult<BytesRef> assignedClass = null;
double maxscore = -Double.MAX_VALUE;
- for (ClassificationResult<BytesRef> element : doclist) {
- if (element.getScore() > maxscore) {
- retval = element;
- maxscore = element.getScore();
+ for (ClassificationResult<BytesRef> c : assignedClasses) {
+ if (c.getScore() > maxscore) {
+ assignedClass = c;
+ maxscore = c.getScore();
}
}
- return retval;
+ return assignedClass;
}
/**
@@ -119,9 +120,9 @@ public class SimpleNaiveBayesClassifier
*/
@Override
public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
- List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(text);
- Collections.sort(doclist);
- return doclist;
+ List<ClassificationResult<BytesRef>> assignedClasses = assignClassNormalizedList(text);
+ Collections.sort(assignedClasses);
+ return assignedClasses;
}
/**
@@ -129,9 +130,9 @@ public class SimpleNaiveBayesClassifier
*/
@Override
public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
- List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(text);
- Collections.sort(doclist);
- return doclist.subList(0, max);
+ List<ClassificationResult<BytesRef>> assignedClasses = assignClassNormalizedList(text);
+ Collections.sort(assignedClasses);
+ return assignedClasses.subList(0, max);
}
/**
@@ -141,46 +142,26 @@ public class SimpleNaiveBayesClassifier
* @throws IOException if assigning probabilities fails
*/
protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
- List<ClassificationResult<BytesRef>> dataList = new ArrayList<>();
+ List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
- Terms terms = MultiFields.getTerms(leafReader, classFieldName);
- TermsEnum termsEnum = terms.iterator();
+ Terms classes = MultiFields.getTerms(leafReader, classFieldName);
+ TermsEnum classesEnum = classes.iterator();
BytesRef next;
- String[] tokenizedDoc = tokenizeDoc(inputDocument);
+ String[] tokenizedText = tokenize(inputDocument);
int docsWithClassSize = countDocsWithClass();
- while ((next = termsEnum.next()) != null) {
+ while ((next = classesEnum.next()) != null) {
if (next.length > 0) {
// We are passing the term to IndexSearcher so we need to make sure it will not change over time
next = BytesRef.deepCopyOf(next);
- double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedDoc, next, docsWithClassSize);
- dataList.add(new ClassificationResult<>(next, clVal));
+ double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedText, next, docsWithClassSize);
+ assignedClasses.add(new ClassificationResult<>(next, clVal));
}
}
// normalization; the values transforms to a 0-1 range
- ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
- if (!dataList.isEmpty()) {
- Collections.sort(dataList);
- // this is a negative number closest to 0 = a
- double smax = dataList.get(0).getScore();
-
- double sumLog = 0;
- // log(sum(exp(x_n-a)))
- for (ClassificationResult<BytesRef> cr : dataList) {
- // getScore-smax <=0 (both negative, smax is the smallest abs()
- sumLog += Math.exp(cr.getScore() - smax);
- }
- // loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n)))
- double loga = smax;
- loga += Math.log(sumLog);
+ ArrayList<ClassificationResult<BytesRef>> assignedClassesNorm = normClassificationResults(assignedClasses);
- // 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum))
- for (ClassificationResult<BytesRef> cr : dataList) {
- returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(cr.getScore() - loga)));
- }
- }
-
- return returnList;
+ return assignedClassesNorm;
}
/**
@@ -192,15 +173,15 @@ public class SimpleNaiveBayesClassifier
protected int countDocsWithClass() throws IOException {
int docCount = MultiFields.getTerms(this.leafReader, this.classFieldName).getDocCount();
if (docCount == -1) { // in case codec doesn't support getDocCount
- TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
+ TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector();
BooleanQuery.Builder q = new BooleanQuery.Builder();
q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST));
if (query != null) {
q.add(query, BooleanClause.Occur.MUST);
}
indexSearcher.search(q.build(),
- totalHitCountCollector);
- docCount = totalHitCountCollector.getTotalHits();
+ classQueryCountCollector);
+ docCount = classQueryCountCollector.getTotalHits();
}
return docCount;
}
@@ -208,14 +189,14 @@ public class SimpleNaiveBayesClassifier
/**
* tokenize a <code>String</code> on this classifier's text fields and analyzer
*
- * @param doc the <code>String</code> representing an input text (to be classified)
+ * @param text the <code>String</code> representing an input text (to be classified)
* @return a <code>String</code> array of the resulting tokens
* @throws IOException if tokenization fails
*/
- protected String[] tokenizeDoc(String doc) throws IOException {
+ protected String[] tokenize(String text) throws IOException {
Collection<String> result = new LinkedList<>();
for (String textFieldName : textFieldNames) {
- try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, doc)) {
+ try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
tokenStream.reset();
while (tokenStream.incrementToken()) {
@@ -227,18 +208,18 @@ public class SimpleNaiveBayesClassifier
return result.toArray(new String[result.size()]);
}
- private double calculateLogLikelihood(String[] tokenizedDoc, BytesRef c, int docsWithClassSize) throws IOException {
+ private double calculateLogLikelihood(String[] tokenizedText, BytesRef c, int docsWithClass) throws IOException {
// for each word
double result = 0d;
- for (String word : tokenizedDoc) {
+ for (String word : tokenizedText) {
// search with text:word AND class:c
- int hits = getWordFreqForClass(word, c);
+ int hits = getWordFreqForClass(word,c);
// num : count the no of times the word appears in documents of class c (+1)
double num = hits + 1; // +1 is added because of add 1 smoothing
// den : for the whole dictionary, count the no of times a word appears in documents of class c (+|V|)
- double den = getTextTermFreqForClass(c) + docsWithClassSize;
+ double den = getTextTermFreqForClass(c) + docsWithClass;
// P(w|c) = num/den
double wordProbability = num / den;
@@ -249,6 +230,12 @@ public class SimpleNaiveBayesClassifier
return result;
}
+ /**
+ * Returns the average number of unique terms times the number of docs belonging to the input class
+ * @param c the class
+ * @return the average number of unique terms
+ * @throws IOException if a low level I/O problem happens
+ */
private double getTextTermFreqForClass(BytesRef c) throws IOException {
double avgNumberOfUniqueTerms = 0;
for (String textFieldName : textFieldNames) {
@@ -260,6 +247,14 @@ public class SimpleNaiveBayesClassifier
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
}
+ /**
+ * Returns the number of documents of the input class (from the whole index or from a subset)
+ * that contain the word (in a specific field, or in all the fields if none is selected)
+ * @param word the token produced by the analyzer
+ * @param c the class
+ * @return the number of documents of the input class
+ * @throws IOException if a low level I/O problem happens
+ */
private int getWordFreqForClass(String word, BytesRef c) throws IOException {
BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder();
BooleanQuery.Builder subQuery = new BooleanQuery.Builder();
@@ -283,4 +278,36 @@ public class SimpleNaiveBayesClassifier
private int docCount(BytesRef countedClass) throws IOException {
return leafReader.docFreq(new Term(classFieldName, countedClass));
}
+
+ /**
+  * Normalizes the scores of the given classification results into the 0-1 range.
+  * <p>
+  * The scores are treated as log-scale values: each one is replaced by
+  * <code>exp(score - log(sum(exp(score_n))))</code>, where the log-sum is computed relative to the
+  * first score of the sorted list to limit floating point underflow.
+  * <p>
+  * The input list is left untouched: sorting is performed on a private copy, so callers may pass
+  * an unmodifiable list and do not see their list reordered.
+  *
+  * @param assignedClasses the list of assigned classes (expected to carry log-scale scores)
+  * @return the normalized results, in the natural ordering of {@link ClassificationResult};
+  *         an empty list if <code>assignedClasses</code> is empty
+  */
+ protected ArrayList<ClassificationResult<BytesRef>> normClassificationResults(List<ClassificationResult<BytesRef>> assignedClasses) {
+   // normalization; the values are transformed to a 0-1 range
+   ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
+   if (!assignedClasses.isEmpty()) {
+     // sort a copy so that the caller's list is neither reordered nor required to be modifiable
+     List<ClassificationResult<BytesRef>> sortedClasses = new ArrayList<>(assignedClasses);
+     Collections.sort(sortedClasses);
+     // this is a negative number closest to 0 = a
+     // NOTE(review): relies on ClassificationResult sorting the best score first - confirm against compareTo
+     double smax = sortedClasses.get(0).getScore();
+
+     double sumLog = 0;
+     // log(sum(exp(x_n-a)))
+     for (ClassificationResult<BytesRef> cr : sortedClasses) {
+       // getScore-smax <= 0 (both negative, smax has the smallest abs())
+       sumLog += Math.exp(cr.getScore() - smax);
+     }
+     // loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n)))
+     double loga = smax;
+     loga += Math.log(sumLog);
+
+     // 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum))
+     for (ClassificationResult<BytesRef> cr : sortedClasses) {
+       double scoreDiff = cr.getScore() - loga;
+       returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(scoreDiff)));
+     }
+   }
+   return returnList;
+ }
}
Added: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/document/DocumentClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/document/DocumentClassifier.java?rev=1709522&view=auto
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/document/DocumentClassifier.java (added)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/document/DocumentClassifier.java Tue Oct 20 07:36:41 2015
@@ -0,0 +1,61 @@
+package org.apache.lucene.classification.document;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.lucene.classification.ClassificationResult;
+import org.apache.lucene.document.Document;
+
+/**
+ * A classifier, see <code>http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>, which assigns classes of type
+ * <code>T</code> to {@link org.apache.lucene.document.Document}s
+ *
+ * @param <T> the type of the classes assigned by this classifier
+ * @lucene.experimental
+ */
+public interface DocumentClassifier<T> {
+ /**
+ * Assign a class (with score) to the given {@link org.apache.lucene.document.Document}
+ *
+ * @param document a {@link org.apache.lucene.document.Document} to be classified. Fields are considered features for the classification.
+ * @return a {@link org.apache.lucene.classification.ClassificationResult} holding the assigned class of type <code>T</code> and its score
+ * @throws java.io.IOException If there is a low-level I/O error.
+ */
+ ClassificationResult<T> assignClass(Document document) throws IOException;
+
+ /**
+ * Get all the classes (sorted by score, descending) assigned to the given {@link org.apache.lucene.document.Document}.
+ *
+ * @param document a {@link org.apache.lucene.document.Document} to be classified. Fields are considered features for the classification.
+ * @return the whole list of {@link org.apache.lucene.classification.ClassificationResult}, the classes and scores. Returns <code>null</code> if the classifier can't make lists.
+ * @throws java.io.IOException If there is a low-level I/O error.
+ */
+ List<ClassificationResult<T>> getClasses(Document document) throws IOException;
+
+ /**
+ * Get the first <code>max</code> classes (sorted by score, descending) assigned to the given {@link org.apache.lucene.document.Document}.
+ *
+ * @param document a {@link org.apache.lucene.document.Document} to be classified. Fields are considered features for the classification.
+ * @param max the number of return list elements
+ * @return the list of {@link org.apache.lucene.classification.ClassificationResult}, the classes and scores, cut to <code>max</code> elements. Returns <code>null</code> if the classifier can't make lists.
+ * @throws java.io.IOException If there is a low-level I/O error.
+ */
+ List<ClassificationResult<T>> getClasses(Document document, int max) throws IOException;
+
+}
Added: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifier.java?rev=1709522&view=auto
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifier.java (added)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifier.java Tue Oct 20 07:36:41 2015
@@ -0,0 +1,146 @@
+package org.apache.lucene.classification.document;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import java.io.StringReader;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.classification.ClassificationResult;
+import org.apache.lucene.classification.KNearestNeighborClassifier;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.WildcardQuery;
+import org.apache.lucene.search.similarities.Similarity;
+import org.apache.lucene.util.BytesRef;
+
+/**
+ * A k-Nearest Neighbor Document classifier (see <code>http://en.wikipedia.org/wiki/K-nearest_neighbors</code>) based
+ * on {@link org.apache.lucene.queries.mlt.MoreLikeThis}.
+ * <p>
+ * Each text field of the input {@link Document} is turned into a More Like This query, analyzed with the
+ * {@link Analyzer} registered for that field; the classes of the top <code>k</code> matching documents are
+ * then used to build the classification results.
+ *
+ * @lucene.experimental
+ */
+public class KNearestNeighborDocumentClassifier extends KNearestNeighborClassifier implements DocumentClassifier<BytesRef> {
+  /**
+   * map from a field name (without boosting indication) to the {@link Analyzer} used on the values of that field
+   */
+  protected Map<String, Analyzer> field2analyzer;
+
+  /**
+   * Creates a {@link KNearestNeighborDocumentClassifier}.
+   *
+   * @param leafReader     the reader on the index to be used for classification
+   * @param similarity     the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null}
+   *                       (defaults to {@link org.apache.lucene.search.similarities.ClassicSimilarity})
+   * @param query          a {@link org.apache.lucene.search.Query} to eventually filter the docs used for training the classifier, or {@code null}
+   *                       if all the indexed docs should be used
+   * @param k              the no. of docs to select in the MLT results to find the nearest neighbor
+   * @param minDocsFreq    {@link org.apache.lucene.queries.mlt.MoreLikeThis#minDocFreq} parameter
+   * @param minTermFreq    {@link org.apache.lucene.queries.mlt.MoreLikeThis#minTermFreq} parameter
+   * @param classFieldName the name of the field used as the output for the classifier
+   * @param field2analyzer map with key a field name and the related {@link org.apache.lucene.analysis.Analyzer}
+   * @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
+   */
+  public KNearestNeighborDocumentClassifier(LeafReader leafReader, Similarity similarity, Query query, int k, int minDocsFreq,
+                                            int minTermFreq, String classFieldName, Map<String, Analyzer> field2analyzer, String... textFieldNames) {
+    // no single analyzer is handed to the parent: analyzers are picked per field from field2analyzer at search time
+    super(leafReader, similarity, null, query, k, minDocsFreq, minTermFreq, classFieldName, textFieldNames);
+    this.field2analyzer = field2analyzer;
+  }
+
+  /**
+   * {@inheritDoc}
+   * <p>
+   * Returns <code>null</code> when no candidate class could be built from the nearest neighbors.
+   */
+  @Override
+  public ClassificationResult<BytesRef> assignClass(Document document) throws IOException {
+    ClassificationResult<BytesRef> assignedClass = null;
+    double maxscore = -Double.MAX_VALUE;
+    // pick the candidate with the highest score
+    for (ClassificationResult<BytesRef> cl : classify(document)) {
+      if (cl.getScore() > maxscore) {
+        assignedClass = cl;
+        maxscore = cl.getScore();
+      }
+    }
+    return assignedClass;
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public List<ClassificationResult<BytesRef>> getClasses(Document document) throws IOException {
+    List<ClassificationResult<BytesRef>> assignedClasses = classify(document);
+    Collections.sort(assignedClasses);
+    return assignedClasses;
+  }
+
+  /**
+   * {@inheritDoc}
+   * <p>
+   * If fewer than <code>max</code> classes are available, all of them are returned.
+   */
+  @Override
+  public List<ClassificationResult<BytesRef>> getClasses(Document document, int max) throws IOException {
+    List<ClassificationResult<BytesRef>> assignedClasses = classify(document);
+    Collections.sort(assignedClasses);
+    // clamp the upper bound: subList throws IndexOutOfBoundsException if it exceeds the list size
+    return assignedClasses.subList(0, Math.min(max, assignedClasses.size()));
+  }
+
+  /**
+   * Runs the k-nearest-neighbor search for the given document and converts the hits into classification results.
+   *
+   * @param document the document to be classified
+   * @return the (unsorted) classification results built from the nearest neighbors
+   * @throws IOException If there is a low-level I/O error
+   */
+  private List<ClassificationResult<BytesRef>> classify(Document document) throws IOException {
+    TopDocs knnResults = knnSearch(document);
+    return buildListFromTopDocs(knnResults);
+  }
+
+  /**
+   * Returns the top k results from a More Like This query based on the input document
+   *
+   * @param document the document to use for More Like This search
+   * @return the top results for the MLT query
+   * @throws IOException If there is a low-level I/O error
+   */
+  private TopDocs knnSearch(Document document) throws IOException {
+    BooleanQuery.Builder mltQuery = new BooleanQuery.Builder();
+
+    for (String fieldName : textFieldNames) {
+      String boost = null;
+      // a text field name may carry a boosting indication, e.g. title^10
+      if (fieldName.contains("^")) {
+        String[] field2boost = fieldName.split("\\^");
+        fieldName = field2boost[0];
+        boost = field2boost[1];
+      }
+      String[] fieldValues = document.getValues(fieldName);
+      if (boost != null) {
+        mlt.setBoost(true);
+        mlt.setBoostFactor(Float.parseFloat(boost));
+      }
+      // NOTE(review): mlt is shared, mutable state configured per field here - concurrent calls look unsafe, confirm
+      mlt.setAnalyzer(field2analyzer.get(fieldName));
+      for (String fieldContent : fieldValues) {
+        mltQuery.add(new BooleanClause(mlt.like(fieldName, new StringReader(fieldContent)), BooleanClause.Occur.SHOULD));
+      }
+      // boosting is per field: switch it off again before handling the next field
+      mlt.setBoost(false);
+    }
+    // only documents that actually have a value in the class field can act as neighbors
+    Query classFieldQuery = new WildcardQuery(new Term(classFieldName, "*"));
+    mltQuery.add(new BooleanClause(classFieldQuery, BooleanClause.Occur.MUST));
+    if (query != null) {
+      mltQuery.add(query, BooleanClause.Occur.MUST);
+    }
+    return indexSearcher.search(mltQuery.build(), k);
+  }
+}
Added: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifier.java?rev=1709522&view=auto
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifier.java (added)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifier.java Tue Oct 20 07:36:41 2015
@@ -0,0 +1,289 @@
+package org.apache.lucene.classification.document;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.lucene.classification.ClassificationResult;
+import org.apache.lucene.classification.SimpleNaiveBayesClassifier;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.MultiFields;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.Terms;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.TotalHitCountCollector;
+import org.apache.lucene.search.WildcardQuery;
+import org.apache.lucene.util.BytesRef;
+
+/**
+ * A simplistic Lucene based NaiveBayes classifier working on whole {@link Document}s, see
+ * {@code http://en.wikipedia.org/wiki/Naive_Bayes_classifier}
+ * <p>
+ * Each configured text field of the input document is analyzed with its own analyzer and contributes to the
+ * score of every class found in the class field of the index. A field name may carry a boost, e.g. {@code title^10}.
+ *
+ * @lucene.experimental
+ */
+ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifier implements DocumentClassifier<BytesRef> {
+   /**
+    * {@link org.apache.lucene.analysis.Analyzer}s to be used for tokenizing document fields, keyed by field name
+    */
+   protected Map<String, Analyzer> field2analyzer;
+
+   /**
+    * Creates a new NaiveBayes classifier.
+    *
+    * @param leafReader     the reader on the index to be used for classification
+    * @param query          a {@link org.apache.lucene.search.Query} to eventually filter the docs used for training the classifier, or {@code null}
+    *                       if all the indexed docs should be used
+    * @param classFieldName the name of the field used as the output for the classifier NOTE: must not be heavily analyzed
+    *                       as the returned class will be a token indexed for this field
+    * @param field2analyzer a map associating to each text field name the analyzer used to tokenize its values
+    * @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
+    */
+   public SimpleNaiveBayesDocumentClassifier(LeafReader leafReader, Query query, String classFieldName, Map<String, Analyzer> field2analyzer, String... textFieldNames) {
+     super(leafReader, null, query, classFieldName, textFieldNames);
+     this.field2analyzer = field2analyzer;
+   }
+
+   /**
+    * {@inheritDoc}
+    * <p>
+    * Returns {@code null} when no class obtains a comparable score, e.g. when the index holds no class at all.
+    */
+   @Override
+   public ClassificationResult<BytesRef> assignClass(Document document) throws IOException {
+     List<ClassificationResult<BytesRef>> assignedClasses = assignNormClasses(document);
+     ClassificationResult<BytesRef> assignedClass = null;
+     double maxscore = -Double.MAX_VALUE;
+     for (ClassificationResult<BytesRef> c : assignedClasses) {
+       if (c.getScore() > maxscore) {
+         assignedClass = c;
+         maxscore = c.getScore();
+       }
+     }
+     return assignedClass;
+   }
+
+   /**
+    * {@inheritDoc}
+    */
+   @Override
+   public List<ClassificationResult<BytesRef>> getClasses(Document document) throws IOException {
+     List<ClassificationResult<BytesRef>> assignedClasses = assignNormClasses(document);
+     Collections.sort(assignedClasses);
+     return assignedClasses;
+   }
+
+   /**
+    * {@inheritDoc}
+    * <p>
+    * When fewer than {@code max} classes are available, all of them are returned.
+    */
+   @Override
+   public List<ClassificationResult<BytesRef>> getClasses(Document document, int max) throws IOException {
+     List<ClassificationResult<BytesRef>> assignedClasses = assignNormClasses(document);
+     Collections.sort(assignedClasses);
+     // asking for more classes than the index contains must not fail
+     return assignedClasses.subList(0, Math.min(max, assignedClasses.size()));
+   }
+
+   /**
+    * Scores the input document against every class indexed in the class field and normalizes the scores.
+    *
+    * @param inputDocument the unseen document to classify
+    * @return the normalized classification results, one per class; empty if the index has no term for the class field
+    * @throws IOException If there is a low-level I/O error
+    */
+   private List<ClassificationResult<BytesRef>> assignNormClasses(Document inputDocument) throws IOException {
+     List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
+     Terms classes = MultiFields.getTerms(leafReader, classFieldName);
+     if (classes == null) {
+       // no class has been indexed at all: nothing to assign
+       return assignedClasses;
+     }
+     Map<String, List<String[]>> fieldName2tokensArray = new LinkedHashMap<>();
+     Map<String, Float> fieldName2boost = new LinkedHashMap<>();
+     analyzeSeedDocument(inputDocument, fieldName2tokensArray, fieldName2boost);
+
+     int docsWithClassSize = countDocsWithClass();
+     TermsEnum classesEnum = classes.iterator();
+     BytesRef c;
+     while ((c = classesEnum.next()) != null) {
+       // the prior depends only on the class, so it is computed once per class
+       double logPrior = calculateLogPrior(c, docsWithClassSize);
+       double classScore = 0;
+       for (String textFieldName : textFieldNames) {
+         String fieldName = stripBoost(textFieldName);
+         float boost = fieldName2boost.get(fieldName);
+         double fieldScore = 0;
+         for (String[] fieldTokensArray : fieldName2tokensArray.get(fieldName)) {
+           fieldScore += logPrior + calculateLogLikelihood(fieldTokensArray, fieldName, c, docsWithClassSize) * boost;
+         }
+         classScore += fieldScore;
+       }
+       // the enum may reuse the BytesRef it returns, hence the deep copy
+       assignedClasses.add(new ClassificationResult<>(BytesRef.deepCopyOf(c), classScore));
+     }
+     return normClassificationResults(assignedClasses);
+   }
+
+   /**
+    * Removes the optional boosting indication from a text field name, e.g. {@code title^10} becomes {@code title}.
+    *
+    * @param textFieldName a text field name as passed to the constructor
+    * @return the plain field name
+    */
+   private static String stripBoost(String textFieldName) {
+     return textFieldName.contains("^") ? textFieldName.split("\\^")[0] : textFieldName;
+   }
+
+   /**
+    * This method performs the analysis for the seed document and extracts the boosts if present.
+    * This is done only one time for the Seed Document.
+    * <p>
+    * The configured text field names are left untouched, so that boosts are still available to later classifications.
+    *
+    * @param inputDocument         the seed unseen document
+    * @param fieldName2tokensArray a map that associates to a field name the list of token arrays for all its values
+    * @param fieldName2boost       a map that associates the boost to the field
+    * @throws IOException If there is a low-level I/O error
+    */
+   private void analyzeSeedDocument(Document inputDocument, Map<String, List<String[]>> fieldName2tokensArray, Map<String, Float> fieldName2boost) throws IOException {
+     for (String textFieldName : textFieldNames) {
+       String fieldName = textFieldName;
+       float boost = 1;
+       if (textFieldName.contains("^")) {
+         String[] field2boost = textFieldName.split("\\^");
+         fieldName = field2boost[0];
+         boost = Float.parseFloat(field2boost[1]);
+       }
+       List<String[]> tokenizedValues = new ArrayList<>();
+       Field[] fieldValues = inputDocument.getFields(fieldName);
+       for (Field fieldValue : fieldValues) {
+         TokenStream fieldTokens = fieldValue.tokenStream(field2analyzer.get(fieldName), null);
+         tokenizedValues.add(getTokenArray(fieldTokens));
+       }
+       fieldName2tokensArray.put(fieldName, tokenizedValues);
+       fieldName2boost.put(fieldName, boost);
+     }
+   }
+
+   /**
+    * Counts the number of documents in the index having at least a value for the 'class' field
+    *
+    * @return the no. of documents having a value for the 'class' field
+    * @throws java.io.IOException If accessing to term vectors or search fails
+    */
+   protected int countDocsWithClass() throws IOException {
+     int docCount = MultiFields.getTerms(this.leafReader, this.classFieldName).getDocCount();
+     if (docCount == -1) { // in case codec doesn't support getDocCount
+       TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector();
+       BooleanQuery.Builder q = new BooleanQuery.Builder();
+       q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST));
+       if (query != null) {
+         q.add(query, BooleanClause.Occur.MUST);
+       }
+       indexSearcher.search(q.build(), classQueryCountCollector);
+       docCount = classQueryCountCollector.getTotalHits();
+     }
+     return docCount;
+   }
+
+   /**
+    * Returns a token array from the {@link org.apache.lucene.analysis.TokenStream} in input.
+    * The stream is always closed, even when tokenization fails.
+    *
+    * @param tokenizedText the tokenized content of a field
+    * @return a {@code String} array of the resulting tokens
+    * @throws java.io.IOException If tokenization fails because there is a low-level I/O error
+    */
+   protected String[] getTokenArray(TokenStream tokenizedText) throws IOException {
+     Collection<String> tokens = new ArrayList<>();
+     try (TokenStream stream = tokenizedText) {
+       CharTermAttribute charTermAttribute = stream.addAttribute(CharTermAttribute.class);
+       stream.reset();
+       while (stream.incrementToken()) {
+         tokens.add(charTermAttribute.toString());
+       }
+       stream.end();
+     }
+     return tokens.toArray(new String[tokens.size()]);
+   }
+
+   /**
+    * Calculates the length-normalized log likelihood of the tokens of a field value given a class.
+    *
+    * @param tokenizedText the tokenized content of a field
+    * @param fieldName     the input field name
+    * @param c             the class to calculate the score of
+    * @param docsWithClass the total number of docs that have a class
+    * @return a normalized score for the class, {@code 0} if the field value produced no token
+    * @throws IOException If there is a low-level I/O error
+    */
+   private double calculateLogLikelihood(String[] tokenizedText, String fieldName, BytesRef c, int docsWithClass) throws IOException {
+     if (tokenizedText.length == 0) {
+       // no token means no evidence: avoid 0/0 = NaN, which would poison the score of every class
+       return 0d;
+     }
+
+     // den : for the whole dictionary, count the no of times a word appears in documents of class c (+|V|)
+     // it does not depend on the word, so it is computed once
+     double den = getTextTermFreqForClass(c, fieldName) + docsWithClass;
+
+     double result = 0d;
+     for (String word : tokenizedText) {
+       // search with text:word AND class:c
+       int hits = getWordFreqForClass(word, fieldName, c);
+
+       // num : count the no of times the word appears in documents of class c (+1)
+       double num = hits + 1; // +1 is added because of add 1 smoothing
+
+       // P(w|c) = num/den
+       result += Math.log(num / den);
+     }
+
+     // log(P(d|c)) = log(P(w1|c))+...+log(P(wn|c))
+     // this is normalized because if not, long text fields will always be more important than short fields
+     return result / tokenizedText.length;
+   }
+
+   /**
+    * Returns the average number of unique terms times the number of docs belonging to the input class
+    *
+    * @param c         the class
+    * @param fieldName the text field the statistics are read from
+    * @return the average number of unique terms
+    * @throws java.io.IOException If there is a low-level I/O error
+    */
+   private double getTextTermFreqForClass(BytesRef c, String fieldName) throws IOException {
+     Terms terms = MultiFields.getTerms(leafReader, fieldName);
+     long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
+     double avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
+     int docsWithC = leafReader.docFreq(new Term(classFieldName, c));
+     return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
+   }
+
+   /**
+    * Returns the number of documents of the input class (from the whole index or from the subset
+    * selected by the filtering query) that contain the word in the given field
+    *
+    * @param word      the token produced by the analyzer
+    * @param fieldName the field the word is coming from
+    * @param c         the class
+    * @return number of documents of the input class
+    * @throws java.io.IOException If there is a low-level I/O error
+    */
+   private int getWordFreqForClass(String word, String fieldName, BytesRef c) throws IOException {
+     BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder();
+     BooleanQuery.Builder subQuery = new BooleanQuery.Builder();
+     subQuery.add(new BooleanClause(new TermQuery(new Term(fieldName, word)), BooleanClause.Occur.SHOULD));
+     booleanQuery.add(new BooleanClause(subQuery.build(), BooleanClause.Occur.MUST));
+     booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
+     if (query != null) {
+       booleanQuery.add(query, BooleanClause.Occur.MUST);
+     }
+     TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
+     indexSearcher.search(booleanQuery.build(), totalHitCountCollector);
+     return totalHitCountCollector.getTotalHits();
+   }
+
+   /**
+    * Calculates log(P(c)) as the log of the ratio between the docs of the class and the docs having any class.
+    *
+    * @param currentClass      the class
+    * @param docsWithClassSize the total number of docs that have a class
+    * @return the log prior of the class
+    * @throws IOException If there is a low-level I/O error
+    */
+   private double calculateLogPrior(BytesRef currentClass, int docsWithClassSize) throws IOException {
+     return Math.log((double) docCount(currentClass)) - Math.log(docsWithClassSize);
+   }
+
+   /**
+    * Returns the number of docs indexed with the given class.
+    *
+    * @param countedClass the class
+    * @return the document frequency of the class term
+    * @throws IOException If there is a low-level I/O error
+    */
+   private int docCount(BytesRef countedClass) throws IOException {
+     return leafReader.docFreq(new Term(classFieldName, countedClass));
+   }
+}
Added: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/document/package-info.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/document/package-info.java?rev=1709522&view=auto
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/document/package-info.java (added)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/document/package-info.java Tue Oct 20 07:36:41 2015
@@ -0,0 +1,7 @@
+/**
+ * Uses already seen data (the indexed documents) to classify new documents.
+ * <p>
+ * Currently contains a (simplistic) Naive Bayes classifier and a k-Nearest
+ * Neighbor classifier.
+ */
+package org.apache.lucene.classification.document;
Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java?rev=1709522&r1=1709521&r2=1709522&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java Tue Oct 20 07:36:41 2015
@@ -16,7 +16,7 @@
*/
/**
- * Uses already seen data (the indexed documents) to classify new documents.
+ * Uses already seen data (the indexed documents) to classify an input (which can be simple text or a structured document).
* <p>
* Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest
* Neighbor classifier and a Perceptron based classifier.
Modified: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java?rev=1709522&r1=1709521&r2=1709522&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java (original)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java Tue Oct 20 07:36:41 2015
@@ -57,8 +57,8 @@ public abstract class ClassificationTest
protected static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
protected RandomIndexWriter indexWriter;
- private Directory dir;
- private FieldType ft;
+ protected Directory dir;
+ protected FieldType ft;
protected String textFieldName;
protected String categoryFieldName;
Added: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/document/DocumentClassificationTestBase.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/document/DocumentClassificationTestBase.java?rev=1709522&view=auto
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/document/DocumentClassificationTestBase.java (added)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/document/DocumentClassificationTestBase.java Tue Oct 20 07:36:41 2015
@@ -0,0 +1,259 @@
+package org.apache.lucene.classification.document;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.en.EnglishAnalyzer;
+import org.apache.lucene.classification.ClassificationResult;
+import org.apache.lucene.classification.ClassificationTestBase;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.RandomIndexWriter;
+import org.apache.lucene.index.SlowCompositeReaderWrapper;
+import org.apache.lucene.util.BytesRef;
+import org.junit.Before;
+
+/**
+ * Base class for testing {@link DocumentClassifier}s.
+ * <p>
+ * Before each test it indexes a small corpus of labeled documents (classes "videogames" and "batman")
+ * plus one unlabeled document, and offers unseen input documents to classify.
+ */
+ public abstract class DocumentClassificationTestBase<T> extends ClassificationTestBase {
+
+   protected static final BytesRef VIDEOGAME_RESULT = new BytesRef("videogames");
+   protected static final BytesRef VIDEOGAME_ANALYZED_RESULT = new BytesRef("videogam");
+   protected static final BytesRef BATMAN_RESULT = new BytesRef("batman");
+
+   // text shared by two indexed "videogames" docs and by the ambiguous input document
+   private static final String EARLY_GAMES_TEXT = "Early games used interactive electronic devices with various display formats. The earliest example is" +
+       " from 1947—a device was filed for a patent on 25 January 1947, by Thomas T. Goldsmith Jr. and Estle Ray Mann," +
+       " and issued on 14 December 1948, as U.S. Patent 2455992.[2]" +
+       "Inspired by radar display tech, it consisted of an analog device that allowed a user to control a vector-drawn" +
+       " dot on the screen to simulate a missile being fired at targets, which were drawings fixed to the screen.[3]";
+
+   protected String titleFieldName = "title";
+   protected String authorFieldName = "author";
+
+   protected Analyzer analyzer;
+   protected Map<String, Analyzer> field2analyzer;
+   protected LeafReader leafReader;
+
+   /**
+    * Configures the same English analyzer for the text, title and author fields and populates the index.
+    *
+    * @throws IOException If the index cannot be populated
+    */
+   @Before
+   public void init() throws IOException {
+     analyzer = new EnglishAnalyzer();
+     field2analyzer = new LinkedHashMap<>();
+     field2analyzer.put(textFieldName, analyzer);
+     field2analyzer.put(titleFieldName, analyzer);
+     field2analyzer.put(authorFieldName, analyzer);
+     leafReader = populateDocumentClassificationIndex(analyzer);
+   }
+
+   /**
+    * Classifies the input document and checks both the assigned class and the range of its score.
+    *
+    * @param classifier     the classifier under test
+    * @param inputDoc       the document to classify
+    * @param expectedResult the class expected to be assigned
+    * @return the score of the assigned class
+    * @throws Exception If classification fails
+    */
+   protected double checkCorrectDocumentClassification(DocumentClassifier<T> classifier, Document inputDoc, T expectedResult) throws Exception {
+     ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
+     assertNotNull(classificationResult.getAssignedClass());
+     assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
+     double score = classificationResult.getScore();
+     assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0);
+     return score;
+   }
+
+   /**
+    * Recreates the index with four "videogames" docs, four "batman" docs and one doc without any class.
+    *
+    * @param analyzer the analyzer used at indexing time
+    * @return a reader on the freshly populated index
+    * @throws IOException If indexing fails
+    */
+   protected LeafReader populateDocumentClassificationIndex(Analyzer analyzer) throws IOException {
+     indexWriter.close();
+     indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE));
+     indexWriter.commit();
+
+     addLabeledDocument("Video games are an economic business",
+         "Video games have become an art form and an industry. The video game industry is of increasing" +
+             " commercial importance, with growth driven particularly by the emerging Asian markets and mobile games." +
+             " As of 2015, video games generated sales of USD 74 billion annually worldwide, and were the third-largest" +
+             " segment in the U.S. entertainment market, behind broadcast and cable TV.",
+         "Ign", "videogames");
+
+     addLabeledDocument("Video games: the definition of fun on PC and consoles",
+         "A video game is an electronic game that involves human interaction with a user interface to generate" +
+             " visual feedback on a video device. The word video in video game traditionally referred to a raster display device," +
+             "[1] but it now implies any type of display device that can produce two- or three-dimensional images." +
+             " The electronic systems used to play video games are known as platforms; examples of these are personal" +
+             " computers and video game consoles. These platforms range from large mainframe computers to small handheld devices." +
+             " Specialized video games such as arcade games, while previously common, have gradually declined in use.",
+         "Ign", "videogames");
+
+     addLabeledDocument("Video games: the history across PC, consoles and fun", EARLY_GAMES_TEXT, "Ign", "videogames");
+
+     addLabeledDocument("Video games: the history", EARLY_GAMES_TEXT, "Ign", "videogames");
+
+     addLabeledDocument("Batman: Arkham Knight PC Benchmarks, For What They're Worth",
+         "Although I didn’t spend much time playing Batman: Arkham Origins, I remember the game rather well after" +
+             " testing it on no less than 30 graphics cards and 20 CPUs. Arkham Origins appeared to take full advantage of" +
+             " Unreal Engine 3, it ran smoothly on affordable GPUs, though it’s worth remembering that Origins was developed " +
+             "for last-gen consoles.This week marked the arrival of Batman: Arkham Knight, the fourth entry in WB’s Batman:" +
+             " Arkham series and a direct sequel to 2013’s Arkham Origins 2011’s Arkham City." +
+             "Arkham Knight is also powered by Unreal Engine 3, but you can expect noticeably improved graphics, in part because" +
+             " the PlayStation 4 and Xbox One have replaced the PS3 and 360 as the lowest common denominator.",
+         "Rocksteady Studios", "batman");
+
+     addLabeledDocument("Face-Off: Batman: Arkham Knight, the Dark Knight returns!",
+         "Despite the drama surrounding the PC release leading to its subsequent withdrawal, there's a sense of success" +
+             " in the console space as PlayStation 4 owners, and indeed those on Xbox One, get a superb rendition of Batman:" +
+             " Arkham Knight. It's fair to say Rocksteady sized up each console's strengths well ahead of producing its first" +
+             " current-gen title, and it's paid off in one of the best Batman games we've seen in years.",
+         "Rocksteady Studios", "batman");
+
+     addLabeledDocument("Batman: Arkham Knight Having More Trouble, But This Time not in Gotham",
+         "As news began to break about the numerous issues affecting the PC version of Batman: Arkham Knight, players" +
+             " of the console version breathed a sigh of relief and got back to playing the game. Now players of the PlayStation" +
+             " 4 version are having problems of their own, albeit much less severe ones." +
+             "This time Batman will have a difficult time in Gotham.",
+         "Rocksteady Studios", "batman");
+
+     addLabeledDocument("Batman: Arkham Knight the new legend of Gotham",
+         "As news began to break about the numerous issues affecting the PC version of the game, players" +
+             " of the console version breathed a sigh of relief and got back to play. Now players of the PlayStation" +
+             " 4 version are having problems of their own, albeit much less severe ones.",
+         "Rocksteady Studios", "batman");
+
+     // a doc without class: classifiers must not take it into account
+     Document unlabeled = new Document();
+     unlabeled.add(new Field(textFieldName, "unlabeled doc", ft));
+     indexWriter.addDocument(unlabeled);
+
+     indexWriter.commit();
+     return SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
+   }
+
+   /**
+    * Indexes one training document made of text, title, author, class and boolean ("false") fields.
+    */
+   private void addLabeledDocument(String title, String text, String author, String category) throws IOException {
+     Document doc = new Document();
+     doc.add(new Field(textFieldName, text, ft));
+     doc.add(new Field(titleFieldName, title, ft));
+     doc.add(new Field(authorFieldName, author, ft));
+     doc.add(new Field(categoryFieldName, category, ft));
+     doc.add(new Field(booleanFieldName, "false", ft));
+     indexWriter.addDocument(doc);
+   }
+
+   /**
+    * Builds an unseen (class-less) document made of text, one or more titles, author and boolean ("false") fields.
+    */
+   private Document newInputDocument(String text, String author, String... titles) {
+     Document doc = new Document();
+     doc.add(new Field(textFieldName, text, ft));
+     for (String title : titles) {
+       doc.add(new Field(titleFieldName, title, ft));
+     }
+     doc.add(new Field(authorFieldName, author, ft));
+     doc.add(new Field(booleanFieldName, "false", ft));
+     return doc;
+   }
+
+   /**
+    * @return an unseen document about video games, authored by "Ign"
+    */
+   protected Document getVideoGameDocument() {
+     String title = "The new generation of PC and Console Video games";
+     String text = "Recently a lot of games have been released for the latest generations of consoles and personal computers." +
+         "One of them is Batman: Arkham Knight released recently on PS4, X-box and personal computer." +
+         "Another important video game that will be released in November is Assassin's Creed, a classic series that sees its new installement on Halloween." +
+         "Recently a lot of problems affected the Assassin's creed series but this time it should ran smoothly on affordable GPUs." +
+         "Players are waiting for the versions of their favourite video games and so do we.";
+     return newInputDocument(text, "Ign", title);
+   }
+
+   /**
+    * @return an unseen document about Batman, with two values for the title field
+    */
+   protected Document getBatmanDocument() {
+     String title = "Batman: Arkham Knight new adventures for the super hero across Gotham, the Dark Knight has returned!";
+     String title2 = "I am a second title !";
+     // NOTE: the escaped quote, plus sign and newline inside the text below are part of the fixture; kept verbatim
+     String text = "This game is the electronic version of the famous super hero adventures.It involves the interaction with the open world" +
+         " of the city of Gotham. Finally the player will be able to have fun on its personal device." +
+         " The three-dimensional images of the game are stunning, because it uses the Unreal Engine 3." +
+         " The systems available are PS4, X-Box and personal computer." +
+         " Will the simulate missile that is going to be fired, success ?\" +\n" +
+         " Will this video game make the history" +
+         " Help you favourite super hero to defeat all his enemies. The Dark Knight has returned !";
+     return newInputDocument(text, "Rocksteady Studios", title, title2);
+   }
+
+   /**
+    * @return an unseen document whose title is about Batman while its text and author come from the "videogames" docs
+    */
+   protected Document getBatmanAmbiguosDocument() {
+     String title = "Batman: Arkham Knight new adventures for the super hero across Gotham, the Dark Knight has returned! Batman will win !";
+     return newInputDocument(EARLY_GAMES_TEXT, "Ign", title);
+   }
+}
Added: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifierTest.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifierTest.java?rev=1709522&view=auto
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifierTest.java (added)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/document/KNearestNeighborDocumentClassifierTest.java Tue Oct 20 07:36:41 2015
@@ -0,0 +1,96 @@
+package org.apache.lucene.classification.document;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.util.BytesRef;
+import org.junit.Test;
+
+/**
+ * Tests for {@link org.apache.lucene.classification.document.KNearestNeighborDocumentClassifier}
+ */
+public class KNearestNeighborDocumentClassifierTest extends DocumentClassificationTestBase<BytesRef> {
+
+ @Test
+ public void testBasicDocumentClassification() throws Exception {
+ try {
+ Document videoGameDocument = getVideoGameDocument();
+ Document batmanDocument = getBatmanDocument();
+ checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT);
+ checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT);
+ // using only the text field, the expected classes are swapped: the text of each document was made ambiguous on purpose
+ checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);
+ checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT);
+
+ } finally {
+ if (leafReader != null) {
+ leafReader.close();
+ }
+ }
+ }
+
+ @Test
+ public void testBasicDocumentClassificationScore() throws Exception {
+ try {
+ Document videoGameDocument = getVideoGameDocument();
+ Document batmanDocument = getBatmanDocument();
+ double score1 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), videoGameDocument, VIDEOGAME_RESULT);
+ assertEquals(1.0,score1,0);
+ double score2 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), batmanDocument, BATMAN_RESULT);
+ assertEquals(1.0,score2,0);
+ // using only the text field, the expected classes are swapped: the text of each document was made ambiguous on purpose
+ double score3 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), videoGameDocument, BATMAN_RESULT);
+ assertEquals(1.0,score3,0);
+ double score4 = checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName}), batmanDocument, VIDEOGAME_RESULT);
+ assertEquals(1.0,score4,0);
+ } finally {
+ if (leafReader != null) {
+ leafReader.close();
+ }
+ }
+ }
+
+ @Test
+ public void testBoostedDocumentClassification() throws Exception {
+ try {
+ checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName + "^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT);
+ // without the "^100" suffix on the title field, the same ambiguous document is expected to get VIDEOGAME_RESULT instead
+ checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, null, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_RESULT);
+ } finally {
+ if (leafReader != null) {
+ leafReader.close();
+ }
+ }
+ }
+
+ @Test
+ public void testBasicDocumentClassificationWithQuery() throws Exception {
+ try {
+ TermQuery query = new TermQuery(new Term(authorFieldName, "ign"));
+ checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null, query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_RESULT);
+ checkCorrectDocumentClassification(new KNearestNeighborDocumentClassifier(leafReader,null,query, 1, 1, 1, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), VIDEOGAME_RESULT);
+ } finally {
+ if (leafReader != null) {
+ leafReader.close();
+ }
+ }
+ }
+
+}
Added: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifierTest.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifierTest.java?rev=1709522&view=auto
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifierTest.java (added)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifierTest.java Tue Oct 20 07:36:41 2015
@@ -0,0 +1,76 @@
+package org.apache.lucene.classification.document;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.lucene.util.BytesRef;
+import org.junit.Test;
+
+/**
+ * Tests for {@link org.apache.lucene.classification.document.SimpleNaiveBayesDocumentClassifier}
+ */
+public class SimpleNaiveBayesDocumentClassifierTest extends DocumentClassificationTestBase<BytesRef> {
+
+ @Test
+ public void testBasicDocumentClassification() throws Exception {
+ try {
+ checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT);
+ checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT);
+
+ checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT);
+ checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT);
+ } finally {
+ if (leafReader != null) {
+ leafReader.close();
+ }
+ }
+ }
+
+ @Test
+ public void testBasicDocumentClassificationScore() throws Exception {
+ try {
+ double score1 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName, field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getVideoGameDocument(), VIDEOGAME_ANALYZED_RESULT);
+ assertEquals(0.88,score1,0.01);
+ double score2 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanDocument(), BATMAN_RESULT);
+ assertEquals(0.89,score2,0.01);
+ // taking into consideration only the text field: the expected classes are swapped and the asserted scores are lower
+ double score3 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getVideoGameDocument(), BATMAN_RESULT);
+ assertEquals(0.55,score3,0.01);
+ double score4 = checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName}), getBatmanDocument(), VIDEOGAME_ANALYZED_RESULT);
+ assertEquals(0.52,score4,0.01);
+ } finally {
+ if (leafReader != null) {
+ leafReader.close();
+ }
+ }
+ }
+
+ @Test
+ public void testBoostedDocumentClassification() throws Exception {
+ try {
+ checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName+"^100", authorFieldName}), getBatmanAmbiguosDocument(), BATMAN_RESULT);
+ // without the "^100" suffix on the title field, the same ambiguous document is expected to get VIDEOGAME_ANALYZED_RESULT instead
+ checkCorrectDocumentClassification(new SimpleNaiveBayesDocumentClassifier(leafReader, null, categoryFieldName,field2analyzer, new String[]{textFieldName, titleFieldName, authorFieldName}), getBatmanAmbiguosDocument(), VIDEOGAME_ANALYZED_RESULT);
+ } finally {
+ if (leafReader != null) {
+ leafReader.close();
+ }
+ }
+ }
+
+
+}