You are viewing a plain text version of this content. The canonical link for it ("here") was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@lucene.apache.org by jp...@apache.org on 2018/05/16 16:05:45 UTC
[5/6] lucene-solr:branch_7x: LUCENE-8315: Make FeatureField easier to use.
LUCENE-8315: Make FeatureField easier to use.
Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/b5bfcf06
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/b5bfcf06
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/b5bfcf06
Branch: refs/heads/branch_7x
Commit: b5bfcf06bab7ca66e151fa292257e3902d8ccf8a
Parents: 5ad9f0d
Author: Adrien Grand <jp...@gmail.com>
Authored: Wed May 16 17:16:16 2018 +0200
Committer: Adrien Grand <jp...@gmail.com>
Committed: Wed May 16 18:01:19 2018 +0200
----------------------------------------------------------------------
.../apache/lucene/document/FeatureField.java | 87 +++++++++++++-------
.../apache/lucene/document/FeatureQuery.java | 21 ++++-
.../lucene/document/TestFeatureField.java | 38 +++++++--
3 files changed, 109 insertions(+), 37 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/b5bfcf06/lucene/core/src/java/org/apache/lucene/document/FeatureField.java
----------------------------------------------------------------------
diff --git a/lucene/core/src/java/org/apache/lucene/document/FeatureField.java b/lucene/core/src/java/org/apache/lucene/document/FeatureField.java
index 10a7310..2f22308 100644
--- a/lucene/core/src/java/org/apache/lucene/document/FeatureField.java
+++ b/lucene/core/src/java/org/apache/lucene/document/FeatureField.java
@@ -17,18 +17,19 @@
package org.apache.lucene.document;
import java.io.IOException;
+import java.util.Objects;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute;
import org.apache.lucene.index.IndexOptions;
+import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermContext;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.Explanation;
-import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.Similarity.SimScorer;
@@ -83,7 +84,7 @@ import org.apache.lucene.util.BytesRef;
* <p>
* The constants in the above formulas typically need training in order to
* compute optimal values. If you don't know where to start, the
- * {@link #newSaturationQuery(IndexSearcher, String, String)} method uses
+ * {@link #newSaturationQuery(String, String)} method uses
* {@code 1f} as a weight and tries to guess a sensible value for the
* {@code pivot} parameter of the saturation function based on index
* statistics, which shouldn't perform too bad. Here is an example, assuming
@@ -94,7 +95,7 @@ import org.apache.lucene.util.BytesRef;
* .add(new TermQuery(new Term("body", "apache")), Occur.SHOULD)
* .add(new TermQuery(new Term("body", "lucene")), Occur.SHOULD)
* .build();
- * Query boost = FeatureField.newSaturationQuery(searcher, "features", "pagerank");
+ * Query boost = FeatureField.newSaturationQuery("features", "pagerank");
* Query boostedQuery = new BooleanQuery.Builder()
* .add(query, Occur.MUST)
* .add(boost, Occur.SHOULD)
@@ -211,6 +212,7 @@ public final class FeatureField extends Field {
static abstract class FeatureFunction {
abstract SimScorer scorer(String field, float w);
abstract Explanation explain(String field, String feature, float w, int doc, int freq) throws IOException;
+ FeatureFunction rewrite(IndexReader reader) throws IOException { return this; }
}
static final class LogFunction extends FeatureFunction {
@@ -274,24 +276,38 @@ public final class FeatureField extends Field {
static final class SaturationFunction extends FeatureFunction {
- private final float pivot;
+ private final String field, feature;
+ private final Float pivot;
- SaturationFunction(float pivot) {
+ SaturationFunction(String field, String feature, Float pivot) {
+ this.field = field;
+ this.feature = feature;
this.pivot = pivot;
}
@Override
+ public FeatureFunction rewrite(IndexReader reader) throws IOException {
+ if (pivot != null) {
+ return super.rewrite(reader);
+ }
+ float newPivot = computePivotFeatureValue(reader, field, feature);
+ return new SaturationFunction(field, feature, newPivot);
+ }
+
+ @Override
public boolean equals(Object obj) {
if (obj == null || getClass() != obj.getClass()) {
return false;
}
SaturationFunction that = (SaturationFunction) obj;
- return pivot == that.pivot;
+ return Objects.equals(field, that.field) &&
+ Objects.equals(feature, that.feature) &&
+ Objects.equals(pivot, that.pivot);
}
@Override
public int hashCode() {
- return Float.hashCode(pivot);
+ return Objects.hash(field, feature, pivot);
}
@Override
@@ -301,6 +317,10 @@ public final class FeatureField extends Field {
@Override
SimScorer scorer(String field, float weight) {
+ if (pivot == null) {
+ throw new IllegalStateException("Rewrite first");
+ }
+ final float pivot = this.pivot; // unbox
return new SimScorer() {
@Override
public float score(int doc, float freq) {
@@ -447,13 +467,30 @@ public final class FeatureField extends Field {
* @throws IllegalArgumentException if weight is not in (0,64] or pivot is not in (0, +Infinity)
*/
public static Query newSaturationQuery(String fieldName, String featureName, float weight, float pivot) {
+ return newSaturationQuery(fieldName, featureName, weight, Float.valueOf(pivot));
+ }
+
+ /**
+ * Same as {@link #newSaturationQuery(String, String, float, float)} but
+ * {@code 1f} is used as a weight and a reasonably good default pivot value
+ * is computed based on index statistics and is approximately equal to the
+ * geometric mean of all values that exist in the index.
+ * @param fieldName field that stores features
+ * @param featureName name of the feature
+ * @throws IllegalArgumentException if weight is not in (0,64] or pivot is not in (0, +Infinity)
+ */
+ public static Query newSaturationQuery(String fieldName, String featureName) {
+ return newSaturationQuery(fieldName, featureName, 1f, null);
+ }
+
+ private static Query newSaturationQuery(String fieldName, String featureName, float weight, Float pivot) {
if (weight <= 0 || weight > MAX_WEIGHT) {
throw new IllegalArgumentException("weight must be in (0, " + MAX_WEIGHT + "], got: " + weight);
}
- if (pivot <= 0 || Float.isFinite(pivot) == false) {
+ if (pivot != null && (pivot <= 0 || Float.isFinite(pivot) == false)) {
throw new IllegalArgumentException("pivot must be > 0, got: " + pivot);
}
- Query q = new FeatureQuery(fieldName, featureName, new SaturationFunction(pivot));
+ Query q = new FeatureQuery(fieldName, featureName, new SaturationFunction(fieldName, featureName, pivot));
if (weight != 1f) {
q = new BoostQuery(q, weight);
}
@@ -461,25 +498,6 @@ public final class FeatureField extends Field {
}
/**
- * Same as {@link #newSaturationQuery(String, String, float, float)} but
- * uses {@code 1f} as a weight and tries to compute a sensible default value
- * for {@code pivot} using
- * {@link #computePivotFeatureValue(IndexSearcher, String, String)}. This
- * isn't expected to give an optimal configuration of these parameters but
- * should be a good start if you have no idea what the values of these
- * parameters should be.
- * @param searcher the {@link IndexSearcher} that you will search against
- * @param featureFieldName the field that stores features
- * @param featureName the name of the feature
- */
- public static Query newSaturationQuery(IndexSearcher searcher,
- String featureFieldName, String featureName) throws IOException {
- float weight = 1f;
- float pivot = computePivotFeatureValue(searcher, featureFieldName, featureName);
- return newSaturationQuery(featureFieldName, featureName, weight, pivot);
- }
-
- /**
* Return a new {@link Query} that will score documents as
* {@code weight * S^a / (S^a + pivot^a)} where S is the value of the static feature.
* @param fieldName field that stores features
@@ -514,13 +532,20 @@ public final class FeatureField extends Field {
* representation in practice before converting it back to a float. Given that
* floats store the exponent in the higher bits, it means that the result will
* be an approximation of the geometric mean of all feature values.
- * @param searcher the {@link IndexSearcher} to search against
+ * @param reader the {@link IndexReader} to search against
* @param featureField the field that stores features
* @param featureName the name of the feature
*/
- public static float computePivotFeatureValue(IndexSearcher searcher, String featureField, String featureName) throws IOException {
+ static float computePivotFeatureValue(IndexReader reader, String featureField, String featureName) throws IOException {
Term term = new Term(featureField, featureName);
- TermContext context = TermContext.build(searcher.getIndexReader().getContext(), term);
+ TermContext context = TermContext.build(reader.getContext(), term);
+ if (context.docFreq() == 0) {
+ // avoid division by 0
+ // The return value doesn't matter much here, the term doesn't exist,
+ // it will never be used for scoring. Just Make sure to return a legal
+ // value.
+ return 1;
+ }
float avgFreq = (float) ((double) context.totalTermFreq() / context.docFreq());
return decodeFeatureValue(avgFreq);
}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/b5bfcf06/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java
----------------------------------------------------------------------
diff --git a/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java b/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java
index eb71d05..fd9bc45 100644
--- a/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java
@@ -21,6 +21,7 @@ import java.util.Objects;
import java.util.Set;
import org.apache.lucene.document.FeatureField.FeatureFunction;
+import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
@@ -48,6 +49,15 @@ final class FeatureQuery extends Query {
}
@Override
+ public Query rewrite(IndexReader reader) throws IOException {
+ FeatureFunction rewritten = function.rewrite(reader);
+ if (function != rewritten) {
+ return new FeatureQuery(fieldName, featureName, rewritten);
+ }
+ return super.rewrite(reader);
+ }
+
+ @Override
public boolean equals(Object obj) {
if (obj == null || getClass() != obj.getClass()) {
return false;
@@ -77,7 +87,16 @@ final class FeatureQuery extends Query {
}
@Override
- public void extractTerms(Set<Term> terms) {}
+ public void extractTerms(Set<Term> terms) {
+ if (needsScores == false) {
+ // features are irrelevant to highlighting, skip
+ } else {
+ // extracting the term here will help get better scoring with
+ // distributed term statistics if the saturation function is used
+ // and the pivot value is computed automatically
+ terms.add(new Term(fieldName, featureName));
+ }
+ }
@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/b5bfcf06/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java
----------------------------------------------------------------------
diff --git a/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java b/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java
index c15c226..88f5ede 100644
--- a/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java
+++ b/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java
@@ -17,10 +17,15 @@
package org.apache.lucene.document;
import java.io.IOException;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.MultiReader;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause.Occur;
@@ -209,7 +214,7 @@ public class TestFeatureField extends LuceneTestCase {
}
public void testSatuSimScorer() throws IOException {
- doTestSimScorer(new FeatureField.SaturationFunction(20f).scorer("foo", 3f));
+ doTestSimScorer(new FeatureField.SaturationFunction("foo", "bar", 20f).scorer("foo", 3f));
}
public void testSigmSimScorer() throws IOException {
@@ -229,6 +234,14 @@ public class TestFeatureField extends LuceneTestCase {
public void testComputePivotFeatureValue() throws IOException {
Directory dir = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), dir, newIndexWriterConfig());
+
+ // Make sure that we create a legal pivot on missing features
+ DirectoryReader reader = writer.getReader();
+ float pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank");
+ assertTrue(Float.isFinite(pivot));
+ assertTrue(pivot > 0);
+ reader.close();
+
Document doc = new Document();
FeatureField pagerank = new FeatureField("features", "pagerank", 1);
doc.add(pagerank);
@@ -247,11 +260,10 @@ public class TestFeatureField extends LuceneTestCase {
pagerank.setFeatureValue(42);
writer.addDocument(doc);
- DirectoryReader reader = writer.getReader();
+ reader = writer.getReader();
writer.close();
- IndexSearcher searcher = new IndexSearcher(reader);
- float pivot = FeatureField.computePivotFeatureValue(searcher, "features", "pagerank");
+ pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank");
double expected = Math.pow(10 * 100 * 1 * 42, 1/4.); // geometric mean
assertEquals(expected, pivot, 0.1);
@@ -259,6 +271,22 @@ public class TestFeatureField extends LuceneTestCase {
dir.close();
}
+ public void testExtractTerms() throws IOException {
+ IndexReader reader = new MultiReader();
+ IndexSearcher searcher = newSearcher(reader);
+ Query query = FeatureField.newLogQuery("field", "term", 2f, 42);
+
+ Weight weight = searcher.createWeight(query, false, 1f);
+ Set<Term> terms = new HashSet<>();
+ weight.extractTerms(terms);
+ assertEquals(Collections.emptySet(), terms);
+
+ terms = new HashSet<>();
+ weight = searcher.createWeight(query, true, 1f);
+ weight.extractTerms(terms);
+ assertEquals(Collections.singleton(new Term("field", "term")), terms);
+ }
+
public void testDemo() throws IOException {
Directory dir = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), dir, newIndexWriterConfig()
@@ -297,7 +325,7 @@ public class TestFeatureField extends LuceneTestCase {
.add(new TermQuery(new Term("body", "apache")), Occur.SHOULD)
.add(new TermQuery(new Term("body", "lucene")), Occur.SHOULD)
.build();
- Query boost = FeatureField.newSaturationQuery(searcher, "features", "pagerank");
+ Query boost = FeatureField.newSaturationQuery("features", "pagerank");
Query boostedQuery = new BooleanQuery.Builder()
.add(query, Occur.MUST)
.add(boost, Occur.SHOULD)