You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by to...@apache.org on 2014/11/12 09:38:06 UTC

svn commit: r1638715 - in /lucene/dev/trunk/lucene/classification/src: java/org/apache/lucene/classification/BooleanPerceptronClassifier.java test/org/apache/lucene/classification/ClassificationTestBase.java

Author: tommaso
Date: Wed Nov 12 08:38:06 2014
New Revision: 1638715

URL: http://svn.apache.org/r1638715
Log:
LUCENE-5699 - normalized score for boolean perceptron classifier

Modified:
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
    lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java?rev=1638715&r1=1638714&r2=1638715&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java Wed Nov 12 08:38:06 2014
@@ -16,6 +16,12 @@
  */
 package org.apache.lucene.classification;
 
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.concurrent.ConcurrentSkipListMap;
+
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.analysis.TokenStream;
 import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
@@ -40,12 +46,6 @@ import org.apache.lucene.util.fst.FST;
 import org.apache.lucene.util.fst.PositiveIntOutputs;
 import org.apache.lucene.util.fst.Util;
 
-import java.io.IOException;
-import java.util.List;
-import java.util.Map;
-import java.util.SortedMap;
-import java.util.TreeMap;
-
 /**
  * A perceptron (see <code>http://en.wikipedia.org/wiki/Perceptron</code>) based
  * <code>Boolean</code> {@link org.apache.lucene.classification.Classifier}. The
@@ -53,7 +53,7 @@ import java.util.TreeMap;
  * {@link org.apache.lucene.index.TermsEnum#totalTermFreq} both on a per field
  * and a per document basis and then a corresponding
  * {@link org.apache.lucene.util.fst.FST} is used for class assignment.
- * 
+ *
  * @lucene.experimental
  */
 public class BooleanPerceptronClassifier implements Classifier<Boolean> {
@@ -67,9 +67,8 @@ public class BooleanPerceptronClassifier
 
   /**
    * Create a {@link BooleanPerceptronClassifier}
-   * 
-   * @param threshold
-   *          the binary threshold for perceptron output evaluation
+   *
+   * @param threshold the binary threshold for perceptron output evaluation
    */
   public BooleanPerceptronClassifier(Double threshold, Integer batchSize) {
     this.threshold = threshold;
@@ -98,7 +97,7 @@ public class BooleanPerceptronClassifier
     Long output = 0l;
     try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
       CharTermAttribute charTermAttribute = tokenStream
-        .addAttribute(CharTermAttribute.class);
+          .addAttribute(CharTermAttribute.class);
       tokenStream.reset();
       while (tokenStream.incrementToken()) {
         String s = charTermAttribute.toString();
@@ -110,7 +109,8 @@ public class BooleanPerceptronClassifier
       tokenStream.end();
     }
 
-    return new ClassificationResult<>(output >= threshold, output.doubleValue());
+    double score = 1 - Math.exp(-1 * Math.abs(threshold - output.doubleValue()) / threshold);
+    return new ClassificationResult<>(output >= threshold, score);
   }
 
   /**
@@ -127,7 +127,7 @@ public class BooleanPerceptronClassifier
    */
   @Override
   public void train(LeafReader leafReader, String textFieldName,
-      String classFieldName, Analyzer analyzer, Query query) throws IOException {
+                    String classFieldName, Analyzer analyzer, Query query) throws IOException {
     this.textTerms = MultiFields.getTerms(leafReader, textFieldName);
 
     if (textTerms == null) {
@@ -150,7 +150,7 @@ public class BooleanPerceptronClassifier
     }
 
     // TODO : remove this map as soon as we have a writable FST
-    SortedMap<String,Double> weights = new TreeMap<>();
+    SortedMap<String, Double> weights = new ConcurrentSkipListMap<>();
 
     TermsEnum reuse = textTerms.iterator(null);
     BytesRef textTerm;
@@ -177,10 +177,10 @@ public class BooleanPerceptronClassifier
       ClassificationResult<Boolean> classificationResult = assignClass(doc
           .getField(textFieldName).stringValue());
       Boolean assignedClass = classificationResult.getAssignedClass();
-      
+
       // get the expected result
       StorableField field = doc.getField(classFieldName);
-      
+
       Boolean correctClass = Boolean.valueOf(field.stringValue());
       long modifier = correctClass.compareTo(assignedClass);
       if (modifier != 0) {
@@ -198,8 +198,8 @@ public class BooleanPerceptronClassifier
   }
 
   private TermsEnum updateWeights(LeafReader leafReader, TermsEnum reuse,
-      int docId, Boolean assignedClass, SortedMap<String,Double> weights,
-      double modifier, boolean updateFST) throws IOException {
+                                  int docId, Boolean assignedClass, SortedMap<String, Double> weights,
+                                  double modifier, boolean updateFST) throws IOException {
     TermsEnum cte = textTerms.iterator(reuse);
 
     // get the doc term vectors
@@ -231,12 +231,12 @@ public class BooleanPerceptronClassifier
     return reuse;
   }
 
-  private void updateFST(SortedMap<String,Double> weights) throws IOException {
+  private void updateFST(SortedMap<String, Double> weights) throws IOException {
     PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton();
     Builder<Long> fstBuilder = new Builder<>(FST.INPUT_TYPE.BYTE1, outputs);
     BytesRefBuilder scratchBytes = new BytesRefBuilder();
     IntsRefBuilder scratchInts = new IntsRefBuilder();
-    for (Map.Entry<String,Double> entry : weights.entrySet()) {
+    for (Map.Entry<String, Double> entry : weights.entrySet()) {
       scratchBytes.copyChars(entry.getKey());
       fstBuilder.add(Util.toIntsRef(scratchBytes.get(), scratchInts), entry
           .getValue().longValue());

Modified: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java?rev=1638715&r1=1638714&r2=1638715&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java (original)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java Wed Nov 12 08:38:06 2014
@@ -91,7 +91,8 @@ public abstract class ClassificationTest
       ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
       assertNotNull(classificationResult.getAssignedClass());
       assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
-      assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0);
+      double score = classificationResult.getScore();
+      assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0);
     } finally {
       if (leafReader != null)
         leafReader.close();
@@ -110,11 +111,12 @@ public abstract class ClassificationTest
       ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
       assertNotNull(classificationResult.getAssignedClass());
       assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
-      assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0);
+      double score = classificationResult.getScore();
+      assertTrue("score should be between 0 and 1, got: " + score, score <= 1 && score >= 0);
       updateSampleIndex(analyzer);
       ClassificationResult<T> secondClassificationResult = classifier.assignClass(inputDoc);
       assertEquals(classificationResult.getAssignedClass(), secondClassificationResult.getAssignedClass());
-      assertEquals(Double.valueOf(classificationResult.getScore()), Double.valueOf(secondClassificationResult.getScore()));
+      assertEquals(Double.valueOf(score), Double.valueOf(secondClassificationResult.getScore()));
 
     } finally {
       if (leafReader != null)