You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by to...@apache.org on 2014/11/12 09:38:06 UTC
svn commit: r1638715 - in /lucene/dev/trunk/lucene/classification/src:
java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
test/org/apache/lucene/classification/ClassificationTestBase.java
Author: tommaso
Date: Wed Nov 12 08:38:06 2014
New Revision: 1638715
URL: http://svn.apache.org/r1638715
Log:
LUCENE-5699 - normalized score for boolean perceptron classifier
Modified:
lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java?rev=1638715&r1=1638714&r2=1638715&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java Wed Nov 12 08:38:06 2014
@@ -16,6 +16,12 @@
*/
package org.apache.lucene.classification;
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.concurrent.ConcurrentSkipListMap;
+
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
@@ -40,12 +46,6 @@ import org.apache.lucene.util.fst.FST;
import org.apache.lucene.util.fst.PositiveIntOutputs;
import org.apache.lucene.util.fst.Util;
-import java.io.IOException;
-import java.util.List;
-import java.util.Map;
-import java.util.SortedMap;
-import java.util.TreeMap;
-
/**
* A perceptron (see <code>http://en.wikipedia.org/wiki/Perceptron</code>) based
* <code>Boolean</code> {@link org.apache.lucene.classification.Classifier}. The
@@ -53,7 +53,7 @@ import java.util.TreeMap;
* {@link org.apache.lucene.index.TermsEnum#totalTermFreq} both on a per field
* and a per document basis and then a corresponding
* {@link org.apache.lucene.util.fst.FST} is used for class assignment.
- *
+ *
* @lucene.experimental
*/
public class BooleanPerceptronClassifier implements Classifier<Boolean> {
@@ -67,9 +67,8 @@ public class BooleanPerceptronClassifier
/**
* Create a {@link BooleanPerceptronClassifier}
- *
- * @param threshold
- * the binary threshold for perceptron output evaluation
+ *
+ * @param threshold the binary threshold for perceptron output evaluation
*/
public BooleanPerceptronClassifier(Double threshold, Integer batchSize) {
this.threshold = threshold;
@@ -98,7 +97,7 @@ public class BooleanPerceptronClassifier
Long output = 0l;
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
CharTermAttribute charTermAttribute = tokenStream
- .addAttribute(CharTermAttribute.class);
+ .addAttribute(CharTermAttribute.class);
tokenStream.reset();
while (tokenStream.incrementToken()) {
String s = charTermAttribute.toString();
@@ -110,7 +109,8 @@ public class BooleanPerceptronClassifier
tokenStream.end();
}
- return new ClassificationResult<>(output >= threshold, output.doubleValue());
+ double score = 1 - Math.exp(-1 * Math.abs(threshold - output.doubleValue()) / threshold);
+ return new ClassificationResult<>(output >= threshold, score);
}
/**
@@ -127,7 +127,7 @@ public class BooleanPerceptronClassifier
*/
@Override
public void train(LeafReader leafReader, String textFieldName,
- String classFieldName, Analyzer analyzer, Query query) throws IOException {
+ String classFieldName, Analyzer analyzer, Query query) throws IOException {
this.textTerms = MultiFields.getTerms(leafReader, textFieldName);
if (textTerms == null) {
@@ -150,7 +150,7 @@ public class BooleanPerceptronClassifier
}
// TODO : remove this map as soon as we have a writable FST
- SortedMap<String,Double> weights = new TreeMap<>();
+ SortedMap<String, Double> weights = new ConcurrentSkipListMap<>();
TermsEnum reuse = textTerms.iterator(null);
BytesRef textTerm;
@@ -177,10 +177,10 @@ public class BooleanPerceptronClassifier
ClassificationResult<Boolean> classificationResult = assignClass(doc
.getField(textFieldName).stringValue());
Boolean assignedClass = classificationResult.getAssignedClass();
-
+
// get the expected result
StorableField field = doc.getField(classFieldName);
-
+
Boolean correctClass = Boolean.valueOf(field.stringValue());
long modifier = correctClass.compareTo(assignedClass);
if (modifier != 0) {
@@ -198,8 +198,8 @@ public class BooleanPerceptronClassifier
}
private TermsEnum updateWeights(LeafReader leafReader, TermsEnum reuse,
- int docId, Boolean assignedClass, SortedMap<String,Double> weights,
- double modifier, boolean updateFST) throws IOException {
+ int docId, Boolean assignedClass, SortedMap<String, Double> weights,
+ double modifier, boolean updateFST) throws IOException {
TermsEnum cte = textTerms.iterator(reuse);
// get the doc term vectors
@@ -231,12 +231,12 @@ public class BooleanPerceptronClassifier
return reuse;
}
- private void updateFST(SortedMap<String,Double> weights) throws IOException {
+ private void updateFST(SortedMap<String, Double> weights) throws IOException {
PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton();
Builder<Long> fstBuilder = new Builder<>(FST.INPUT_TYPE.BYTE1, outputs);
BytesRefBuilder scratchBytes = new BytesRefBuilder();
IntsRefBuilder scratchInts = new IntsRefBuilder();
- for (Map.Entry<String,Double> entry : weights.entrySet()) {
+ for (Map.Entry<String, Double> entry : weights.entrySet()) {
scratchBytes.copyChars(entry.getKey());
fstBuilder.add(Util.toIntsRef(scratchBytes.get(), scratchInts), entry
.getValue().longValue());
Modified: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java?rev=1638715&r1=1638714&r2=1638715&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java (original)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java Wed Nov 12 08:38:06 2014
@@ -91,7 +91,8 @@ public abstract class ClassificationTest
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
assertNotNull(classificationResult.getAssignedClass());
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
- assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0);
+ double score = classificationResult.getScore();
+ assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0);
} finally {
if (leafReader != null)
leafReader.close();
@@ -110,11 +111,12 @@ public abstract class ClassificationTest
ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
assertNotNull(classificationResult.getAssignedClass());
assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
- assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0);
+ double score = classificationResult.getScore();
+ assertTrue("score should be between 0 and 1, got: " + score, score <= 1 && score >= 0);
updateSampleIndex(analyzer);
ClassificationResult<T> secondClassificationResult = classifier.assignClass(inputDoc);
assertEquals(classificationResult.getAssignedClass(), secondClassificationResult.getAssignedClass());
- assertEquals(Double.valueOf(classificationResult.getScore()), Double.valueOf(secondClassificationResult.getScore()));
+ assertEquals(Double.valueOf(score), Double.valueOf(secondClassificationResult.getScore()));
} finally {
if (leafReader != null)