You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by to...@apache.org on 2012/09/14 08:55:12 UTC

svn commit: r1384657 - in /lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification: Classifier.java SimpleNaiveBayesClassifier.java

Author: tommaso
Date: Fri Sep 14 06:55:12 2012
New Revision: 1384657

URL: http://svn.apache.org/viewvc?rev=1384657&view=rev
Log:
[LUCENE-4345] - starting to incorporate Simon's suggestions: using BytesRef and TotalHitCountCollector

Modified:
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java?rev=1384657&r1=1384656&r2=1384657&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java Fri Sep 14 06:55:12 2012
@@ -24,6 +24,7 @@ import java.io.IOException;
 
 /**
  * A classifier, see <code>http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>
+ * @lucene.experimental
  */
 public interface Classifier {
 

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java?rev=1384657&r1=1384656&r2=1384657&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java Fri Sep 14 06:55:12 2012
@@ -29,6 +29,7 @@ import org.apache.lucene.search.BooleanC
 import org.apache.lucene.search.BooleanQuery;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.TotalHitCountCollector;
 import org.apache.lucene.util.BytesRef;
 
 import java.io.IOException;
@@ -38,6 +39,7 @@ import java.util.LinkedList;
 
 /**
  * A simplistic Lucene based NaiveBayes classifier, see <code>http://en.wikipedia.org/wiki/Naive_Bayes_classifier</code>
+ * @lucene.experimental
  */
 public class SimpleNaiveBayesClassifier implements Classifier {
 
@@ -82,29 +84,27 @@ public class SimpleNaiveBayesClassifier 
     if (atomicReader == null) {
       throw new RuntimeException("need to train the classifier first");
     }
-    Double max = 0d;
+    double max = 0d;
     String foundClass = null;
 
     Terms terms = MultiFields.getTerms(atomicReader, classFieldName);
     TermsEnum termsEnum = terms.iterator(null);
-    BytesRef t = termsEnum.next();
-    while (t != null) {
-      String classValue = t.utf8ToString();
+    BytesRef next;
+    while((next = termsEnum.next()) != null) {
       // TODO : turn it to be in log scale
-      Double clVal = calculatePrior(classValue) * calculateLikelihood(inputDocument, classValue);
+      double clVal = calculatePrior(next) * calculateLikelihood(inputDocument, next);
       if (clVal > max) {
         max = clVal;
-        foundClass = classValue;
+        foundClass = next.utf8ToString();
       }
-      t = termsEnum.next();
     }
     return foundClass;
   }
 
 
-  private Double calculateLikelihood(String document, String c) throws IOException {
+  private double calculateLikelihood(String document, BytesRef c) throws IOException {
     // for each word
-    Double result = 1d;
+    double result = 1d;
     for (String word : tokenizeDoc(document)) {
       // search with text:word AND class:c
       int hits = getWordFreqForClass(word, c);
@@ -124,7 +124,7 @@ public class SimpleNaiveBayesClassifier 
     return result;
   }
 
-  private double getTextTermFreqForClass(String c) throws IOException {
+  private double getTextTermFreqForClass(BytesRef c) throws IOException {
     Terms terms = MultiFields.getTerms(atomicReader, textFieldName);
     long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
     double avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
@@ -132,18 +132,20 @@ public class SimpleNaiveBayesClassifier 
     return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text field per doc * # docs with c
   }
 
-  private int getWordFreqForClass(String word, String c) throws IOException {
+  private int getWordFreqForClass(String word, BytesRef c) throws IOException {
     BooleanQuery booleanQuery = new BooleanQuery();
     booleanQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.MUST));
     booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
-    return indexSearcher.search(booleanQuery, 1).totalHits;
+    TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
+    indexSearcher.search(booleanQuery, totalHitCountCollector);
+    return totalHitCountCollector.getTotalHits();
   }
 
-  private Double calculatePrior(String currentClass) throws IOException {
+  private double calculatePrior(BytesRef currentClass) throws IOException {
     return (double) docCount(currentClass) / docsWithClassSize;
   }
 
-  private int docCount(String countedClass) throws IOException {
+  private int docCount(BytesRef countedClass) throws IOException {
     return atomicReader.docFreq(new Term(classFieldName, countedClass));
   }
 }