You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by to...@apache.org on 2014/08/20 10:56:43 UTC

svn commit: r1619053 - in /lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification: BooleanPerceptronClassifier.java ClassificationResult.java Classifier.java KNearestNeighborClassifier.java SimpleNaiveBayesClassifier.java

Author: tommaso
Date: Wed Aug 20 08:56:42 2014
New Revision: 1619053

URL: http://svn.apache.org/r1619053
Log:
LUCENE-5699 - patch from Gergő Törcsvári for normalized score and return lists in classification

Modified:
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java?rev=1619053&r1=1619052&r2=1619053&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java Wed Aug 20 08:56:42 2014
@@ -42,6 +42,7 @@ import org.apache.lucene.util.fst.Positi
 import org.apache.lucene.util.fst.Util;
 
 import java.io.IOException;
+import java.util.List;
 import java.util.Map;
 import java.util.SortedMap;
 import java.util.TreeMap;
@@ -131,9 +132,7 @@ public class BooleanPerceptronClassifier
     this.textTerms = MultiFields.getTerms(atomicReader, textFieldName);
 
     if (textTerms == null) {
-      throw new IOException(new StringBuilder(
-          "term vectors need to be available for field ").append(textFieldName)
-          .toString());
+      throw new IOException("term vectors need to be available for field " + textFieldName);
     }
 
     this.analyzer = analyzer;
@@ -246,4 +245,22 @@ public class BooleanPerceptronClassifier
     fst = fstBuilder.finish();
   }
 
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public List<ClassificationResult<BytesRef>> getClasses(String text)
+      throws IOException {
+    return null;
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public List<ClassificationResult<BytesRef>> getClasses(String text, int max)
+      throws IOException {
+    return null;
+  }
+
 }
\ No newline at end of file

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java?rev=1619053&r1=1619052&r2=1619053&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java Wed Aug 20 08:56:42 2014
@@ -20,10 +20,10 @@ package org.apache.lucene.classification
  * The result of a call to {@link Classifier#assignClass(String)} holding an assigned class of type <code>T</code> and a score.
  * @lucene.experimental
  */
-public class ClassificationResult<T> {
+public class ClassificationResult<T> implements Comparable<ClassificationResult<T>>{
 
   private final T assignedClass;
-  private final double score;
+  private double score;
 
   /**
    * Constructor
@@ -50,4 +50,18 @@ public class ClassificationResult<T> {
   public double getScore() {
     return score;
   }
+  
+  /**
+   * Sets the score value.
+   * @param score the score for the assignedClass as a <code>double</code>
+   */
+  public void setScore(double score) {
+    this.score = score;
+  }
+
+
+  @Override
+  public int compareTo(ClassificationResult<T> o) {
+    return this.getScore() < o.getScore() ? 1 : this.getScore() > o.getScore() ? -1 : 0;
+  }
 }

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java?rev=1619053&r1=1619052&r2=1619053&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java Wed Aug 20 08:56:42 2014
@@ -16,21 +16,25 @@
  */
 package org.apache.lucene.classification;
 
+import java.io.IOException;
+import java.util.List;
+
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.index.AtomicReader;
 import org.apache.lucene.search.Query;
-
-import java.io.IOException;
+import org.apache.lucene.util.BytesRef;
 
 /**
  * A classifier, see <code>http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>, which assign classes of type
  * <code>T</code>
+ *
  * @lucene.experimental
  */
 public interface Classifier<T> {
 
   /**
    * Assign a class (with score) to the given text String
+   *
    * @param text a String containing text to be classified
    * @return a {@link ClassificationResult} holding assigned class of type <code>T</code> and score
    * @throws IOException If there is a low-level I/O error.
@@ -38,11 +42,31 @@ public interface Classifier<T> {
   public ClassificationResult<T> assignClass(String text) throws IOException;
 
   /**
+   * Get all the classes (sorted by score, descending) assigned to the given text String.
+   *
+   * @param text a String containing text to be classified
+   * @return the complete list of {@link ClassificationResult}s (classes with their scores), or <code>null</code> if the classifier cannot produce a list.
+   * @throws IOException If there is a low-level I/O error.
+   */
+  public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException;
+
+  /**
+   * Get the first <code>max</code> classes (sorted by score, descending) assigned to the given text String.
+   *
+   * @param text a String containing text to be classified
+   * @param max  the number of elements to keep in the returned list
+   * @return the list of {@link ClassificationResult}s (classes with their scores), cut to the first <code>max</code> elements, or <code>null</code> if the classifier cannot produce a list.
+   * @throws IOException If there is a low-level I/O error.
+   */
+  public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException;
+
+  /**
    * Train the classifier using the underlying Lucene index
-   * @param atomicReader the reader to use to access the Lucene index
-   * @param textFieldName the name of the field used to compare documents
+   *
+   * @param atomicReader   the reader to use to access the Lucene index
+   * @param textFieldName  the name of the field used to compare documents
    * @param classFieldName the name of the field containing the class assigned to documents
-   * @param analyzer the analyzer used to tokenize / filter the unseen text
+   * @param analyzer       the analyzer used to tokenize / filter the unseen text
    * @throws IOException If there is a low-level I/O error.
    */
   public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer)
@@ -50,11 +74,12 @@ public interface Classifier<T> {
 
   /**
    * Train the classifier using the underlying Lucene index
-   * @param atomicReader the reader to use to access the Lucene index
-   * @param textFieldName the name of the field used to compare documents
+   *
+   * @param atomicReader   the reader to use to access the Lucene index
+   * @param textFieldName  the name of the field used to compare documents
    * @param classFieldName the name of the field containing the class assigned to documents
-   * @param analyzer the analyzer used to tokenize / filter the unseen text
-   * @param query the query to filter which documents use for training
+   * @param analyzer       the analyzer used to tokenize / filter the unseen text
+   * @param query          the query to filter which documents use for training
    * @throws IOException If there is a low-level I/O error.
    */
   public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
@@ -62,11 +87,12 @@ public interface Classifier<T> {
 
   /**
    * Train the classifier using the underlying Lucene index
-   * @param atomicReader the reader to use to access the Lucene index
+   *
+   * @param atomicReader   the reader to use to access the Lucene index
    * @param textFieldNames the names of the fields to be used to compare documents
    * @param classFieldName the name of the field containing the class assigned to documents
-   * @param analyzer the analyzer used to tokenize / filter the unseen text
-   * @param query the query to filter which documents use for training
+   * @param analyzer       the analyzer used to tokenize / filter the unseen text
+   * @param query          the query to filter which documents use for training
    * @throws IOException If there is a low-level I/O error.
    */
   public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query)

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java?rev=1619053&r1=1619052&r2=1619053&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java Wed Aug 20 08:56:42 2014
@@ -31,7 +31,10 @@ import org.apache.lucene.util.BytesRef;
 
 import java.io.IOException;
 import java.io.StringReader;
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
 /**
@@ -79,6 +82,42 @@ public class KNearestNeighborClassifier 
    */
   @Override
   public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
+    TopDocs topDocs=knnSearcher(text);
+    List<ClassificationResult<BytesRef>> doclist=buildListFromTopDocs(topDocs);
+    ClassificationResult<BytesRef> retval=null;
+    double maxscore=-Double.MAX_VALUE;
+    for(ClassificationResult<BytesRef> element:doclist){
+      if(element.getScore()>maxscore){
+        retval=element;
+        maxscore=element.getScore();
+      }
+    }
+    return retval;
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
+    TopDocs topDocs=knnSearcher(text);
+    List<ClassificationResult<BytesRef>> doclist=buildListFromTopDocs(topDocs);
+    Collections.sort(doclist);
+    return doclist;
+  }
+  
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
+    TopDocs topDocs=knnSearcher(text);
+    List<ClassificationResult<BytesRef>> doclist=buildListFromTopDocs(topDocs);
+    Collections.sort(doclist);
+    return doclist.subList(0, max);
+  }
+
+  private TopDocs knnSearcher(String text) throws IOException{
     if (mlt == null) {
       throw new IOException("You must first call Classifier#train");
     }
@@ -91,33 +130,36 @@ public class KNearestNeighborClassifier 
     if (query != null) {
       mltQuery.add(query, BooleanClause.Occur.MUST);
     }
-    TopDocs topDocs = indexSearcher.search(mltQuery, k);
-    return selectClassFromNeighbors(topDocs);
+    return indexSearcher.search(mltQuery, k);
   }
-
-  private ClassificationResult<BytesRef> selectClassFromNeighbors(TopDocs topDocs) throws IOException {
-    // TODO : improve the nearest neighbor selection
+  
+  private List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
     Map<BytesRef, Integer> classCounts = new HashMap<>();
     for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
-      BytesRef cl = new BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue());
-      Integer count = classCounts.get(cl);
-      if (count != null) {
-        classCounts.put(cl, count + 1);
-      } else {
-        classCounts.put(cl, 1);
-      }
+        BytesRef cl = new BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue());
+        Integer count = classCounts.get(cl);
+        if (count != null) {
+            classCounts.put(cl, count + 1);
+        } else {
+            classCounts.put(cl, 1);
+        }
     }
-    double max = 0;
-    BytesRef assignedClass = new BytesRef();
+    List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
+    int sumdoc=0;
     for (Map.Entry<BytesRef, Integer> entry : classCounts.entrySet()) {
-      Integer count = entry.getValue();
-      if (count > max) {
-        max = count;
-        assignedClass = entry.getKey().clone();
+        Integer count = entry.getValue();
+        returnList.add(new ClassificationResult<>(entry.getKey().clone(), count / (double) k));
+        sumdoc+=count;
+
+    }
+    
+    //correction
+    if(sumdoc<k){
+      for(ClassificationResult<BytesRef> cr:returnList){
+        cr.setScore(cr.getScore()*(double)k/(double)sumdoc);
       }
     }
-    double score = max / (double) k;
-    return new ClassificationResult<>(assignedClass, score);
+    return returnList;
   }
 
   /**

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java?rev=1619053&r1=1619052&r2=1619053&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java Wed Aug 20 08:56:42 2014
@@ -16,6 +16,13 @@
  */
 package org.apache.lucene.classification;
 
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.LinkedList;
+import java.util.List;
+
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.analysis.TokenStream;
 import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
@@ -33,10 +40,6 @@ import org.apache.lucene.search.TotalHit
 import org.apache.lucene.search.WildcardQuery;
 import org.apache.lucene.util.BytesRef;
 
-import java.io.IOException;
-import java.util.Collection;
-import java.util.LinkedList;
-
 /**
  * A simplistic Lucene based NaiveBayes classifier, see <code>http://en.wikipedia.org/wiki/Naive_Bayes_classifier</code>
  *
@@ -44,13 +47,12 @@ import java.util.LinkedList;
  */
 public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
 
-  private AtomicReader atomicReader;
-  private String[] textFieldNames;
-  private String classFieldName;
-  private int docsWithClassSize;
-  private Analyzer analyzer;
-  private IndexSearcher indexSearcher;
-  private Query query;
+  protected AtomicReader atomicReader;
+  protected String[] textFieldNames;
+  protected String classFieldName;
+  protected Analyzer analyzer;
+  protected IndexSearcher indexSearcher;
+  protected Query query;
 
   /**
    * Creates a new NaiveBayes classifier.
@@ -89,10 +91,88 @@ public class SimpleNaiveBayesClassifier 
     this.classFieldName = classFieldName;
     this.analyzer = analyzer;
     this.query = query;
-    this.docsWithClassSize = countDocsWithClass();
   }
 
-  private int countDocsWithClass() throws IOException {
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException {
+    List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(inputDocument);
+    ClassificationResult<BytesRef> retval = null;
+    double maxscore = -Double.MAX_VALUE;
+    for (ClassificationResult<BytesRef> element : doclist) {
+      if (element.getScore() > maxscore) {
+        retval = element;
+        maxscore = element.getScore();
+      }
+    }
+    return retval;
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
+    List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(text);
+    Collections.sort(doclist);
+    return doclist;
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
+    List<ClassificationResult<BytesRef>> doclist = assignClassNormalizedList(text);
+    Collections.sort(doclist);
+    return doclist.subList(0, max);
+  }
+
+  private List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
+    if (atomicReader == null) {
+      throw new IOException("You must first call Classifier#train");
+    }
+    List<ClassificationResult<BytesRef>> dataList = new ArrayList<>();
+
+    Terms terms = MultiFields.getTerms(atomicReader, classFieldName);
+    TermsEnum termsEnum = terms.iterator(null);
+    BytesRef next;
+    String[] tokenizedDoc = tokenizeDoc(inputDocument);
+    int docsWithClassSize = countDocsWithClass();
+    while ((next = termsEnum.next()) != null) {
+      double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedDoc, next, docsWithClassSize);
+      dataList.add(new ClassificationResult<>(BytesRef.deepCopyOf(next), clVal));
+    }
+
+    // normalization: transform the values into the 0-1 range
+    ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
+    if (!dataList.isEmpty()) {
+      Collections.sort(dataList);
+      // this is the negative number closest to 0; it is called "a" in the formulas below
+      double smax = dataList.get(0).getScore();
+
+      double sumLog = 0;
+      // log(sum(exp(x_n-a)))
+      for (ClassificationResult<BytesRef> cr : dataList) {
+        // getScore - smax <= 0 (both are negative; smax has the smallest absolute value)
+        sumLog += Math.exp(cr.getScore() - smax);
+      }
+      // loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n)))
+      double loga = smax;
+      loga += Math.log(sumLog);
+
+      // 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum))
+      for (ClassificationResult<BytesRef> cr : dataList) {
+        returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(cr.getScore() - loga)));
+      }
+    }
+
+    return returnList;
+  }
+
+  protected int countDocsWithClass() throws IOException {
     int docCount = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount();
     if (docCount == -1) { // in case codec doesn't support getDocCount
       TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
@@ -108,7 +188,7 @@ public class SimpleNaiveBayesClassifier 
     return docCount;
   }
 
-  private String[] tokenizeDoc(String doc) throws IOException {
+  protected String[] tokenizeDoc(String doc) throws IOException {
     Collection<String> result = new LinkedList<>();
     for (String textFieldName : textFieldNames) {
       try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, doc)) {
@@ -123,34 +203,7 @@ public class SimpleNaiveBayesClassifier 
     return result.toArray(new String[result.size()]);
   }
 
-  /**
-   * {@inheritDoc}
-   */
-  @Override
-  public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException {
-    if (atomicReader == null) {
-      throw new IOException("You must first call Classifier#train");
-    }
-    double max = - Double.MAX_VALUE;
-    BytesRef foundClass = new BytesRef();
-
-    Terms terms = MultiFields.getTerms(atomicReader, classFieldName);
-    TermsEnum termsEnum = terms.iterator(null);
-    BytesRef next;
-    String[] tokenizedDoc = tokenizeDoc(inputDocument);
-    while ((next = termsEnum.next()) != null) {
-      double clVal = calculateLogPrior(next) + calculateLogLikelihood(tokenizedDoc, next);
-      if (clVal > max) {
-        max = clVal;
-        foundClass = BytesRef.deepCopyOf(next);
-      }
-    }
-    double score = 10 / Math.abs(max);
-    return new ClassificationResult<>(foundClass, score);
-  }
-
-
-  private double calculateLogLikelihood(String[] tokenizedDoc, BytesRef c) throws IOException {
+  private double calculateLogLikelihood(String[] tokenizedDoc, BytesRef c, int docsWithClassSize) throws IOException {
     // for each word
     double result = 0d;
     for (String word : tokenizedDoc) {
@@ -187,7 +240,7 @@ public class SimpleNaiveBayesClassifier 
     BooleanQuery booleanQuery = new BooleanQuery();
     BooleanQuery subQuery = new BooleanQuery();
     for (String textFieldName : textFieldNames) {
-     subQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD));
+      subQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD));
     }
     booleanQuery.add(new BooleanClause(subQuery, BooleanClause.Occur.MUST));
     booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
@@ -199,7 +252,7 @@ public class SimpleNaiveBayesClassifier 
     return totalHitCountCollector.getTotalHits();
   }
 
-  private double calculateLogPrior(BytesRef currentClass) throws IOException {
+  private double calculateLogPrior(BytesRef currentClass, int docsWithClassSize) throws IOException {
     return Math.log((double) docCount(currentClass)) - Math.log(docsWithClassSize);
   }