You are viewing a plain-text version of this content. The canonical link for it (originally attached to the word "here") was not preserved in this plain-text conversion.
Posted to commits@lucene.apache.org by to...@apache.org on 2015/05/04 12:06:32 UTC

svn commit: r1677573 - in /lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification: ./ utils/

Author: tommaso
Date: Mon May  4 10:06:32 2015
New Revision: 1677573

URL: http://svn.apache.org/r1677573
Log:
LUCENE-6045 - immutable ClassificationResult, minor fixes

Modified:
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java?rev=1677573&r1=1677572&r2=1677573&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java Mon May  4 10:06:32 2015
@@ -226,7 +226,7 @@ public class BooleanPerceptronClassifier
   @Override
   public List<ClassificationResult<Boolean>> getClasses(String text)
           throws IOException {
-    throw new RuntimeException("not implemented");
+    return null;
   }
 
   /**
@@ -235,7 +235,7 @@ public class BooleanPerceptronClassifier
   @Override
   public List<ClassificationResult<Boolean>> getClasses(String text, int max)
           throws IOException {
-    throw new RuntimeException("not implemented");
+    return null;
   }
 
 }

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java?rev=1677573&r1=1677572&r2=1677573&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java Mon May  4 10:06:32 2015
@@ -141,12 +141,22 @@ public class CachingNaiveBayesClassifier
         double wordProbability = num / den;
 
         // modify the value in the result list item
+        int removeIdx = -1;
+        int i = 0;
         for (ClassificationResult<BytesRef> cr : ret) {
           if (cr.getAssignedClass().equals(cclass)) {
-            cr.setScore(cr.getScore() + Math.log(wordProbability));
+            removeIdx = i;
             break;
           }
+          i++;
         }
+
+        if (removeIdx >= 0) {
+          ClassificationResult<BytesRef> toRemove = ret.get(removeIdx);
+          ret.add(new ClassificationResult<>(toRemove.getAssignedClass(), toRemove.getScore() + Math.log(wordProbability)));
+          ret.remove(removeIdx);
+        }
+
       }
     }
 

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java?rev=1677573&r1=1677572&r2=1677573&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java Mon May  4 10:06:32 2015
@@ -24,7 +24,7 @@ package org.apache.lucene.classification
 public class ClassificationResult<T> implements Comparable<ClassificationResult<T>> {
 
   private final T assignedClass;
-  private double score;
+  private final double score;
 
   /**
    * Constructor
@@ -55,16 +55,6 @@ public class ClassificationResult<T> imp
     return score;
   }
 
-  /**
-   * set the score value
-   *
-   * @param score the score for the assignedClass as a <code>double</code>
-   */
-  public void setScore(double score) {
-    this.score = score;
-  }
-
-
   @Override
   public int compareTo(ClassificationResult<T> o) {
     return this.getScore() < o.getScore() ? 1 : this.getScore() > o.getScore() ? -1 : 0;

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java?rev=1677573&r1=1677572&r2=1677573&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java Mon May  4 10:06:32 2015
@@ -153,18 +153,21 @@ public class KNearestNeighborClassifier
       }
     }
     List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
+    List<ClassificationResult<BytesRef>> temporaryList = new ArrayList<>();
     int sumdoc = 0;
     for (Map.Entry<BytesRef, Integer> entry : classCounts.entrySet()) {
       Integer count = entry.getValue();
-      returnList.add(new ClassificationResult<>(entry.getKey().clone(), count / (double) k));
+      temporaryList.add(new ClassificationResult<>(entry.getKey().clone(), count / (double) k));
       sumdoc += count;
     }
 
     //correction
     if (sumdoc < k) {
-      for (ClassificationResult<BytesRef> cr : returnList) {
-        cr.setScore(cr.getScore() * (double) k / (double) sumdoc);
+      for (ClassificationResult<BytesRef> cr : temporaryList) {
+        returnList.add(new ClassificationResult<>(cr.getAssignedClass(), cr.getScore() * k / (double) sumdoc));
       }
+    } else {
+      returnList = temporaryList;
     }
     return returnList;
   }

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java?rev=1677573&r1=1677572&r2=1677573&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java Mon May  4 10:06:32 2015
@@ -28,6 +28,7 @@ import org.apache.lucene.index.IndexWrit
 import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.StorableField;
+import org.apache.lucene.index.StoredDocument;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.MatchAllDocsQuery;
 import org.apache.lucene.search.ScoreDoc;
@@ -91,12 +92,16 @@ public class DatasetSplitter {
 
         // create a new document for indexing
         Document doc = new Document();
+        StoredDocument document = originalIndex.document(scoreDoc.doc);
         if (fieldNames != null && fieldNames.length > 0) {
           for (String fieldName : fieldNames) {
-            doc.add(new Field(fieldName, originalIndex.document(scoreDoc.doc).getField(fieldName).stringValue(), ft));
+            StorableField field = document.getField(fieldName);
+            if (field != null) {
+              doc.add(new Field(fieldName, field.stringValue(), ft));
+            }
           }
         } else {
-          for (StorableField storableField : originalIndex.document(scoreDoc.doc).getFields()) {
+          for (StorableField storableField : document.getFields()) {
             if (storableField.readerValue() != null) {
               doc.add(new Field(storableField.name(), storableField.readerValue(), ft));
             } else if (storableField.binaryValue() != null) {