You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by jp...@apache.org on 2018/05/16 16:05:45 UTC

[5/6] lucene-solr:branch_7x: LUCENE-8315: Make FeatureField easier to use.

LUCENE-8315: Make FeatureField easier to use.


Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/b5bfcf06
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/b5bfcf06
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/b5bfcf06

Branch: refs/heads/branch_7x
Commit: b5bfcf06bab7ca66e151fa292257e3902d8ccf8a
Parents: 5ad9f0d
Author: Adrien Grand <jp...@gmail.com>
Authored: Wed May 16 17:16:16 2018 +0200
Committer: Adrien Grand <jp...@gmail.com>
Committed: Wed May 16 18:01:19 2018 +0200

----------------------------------------------------------------------
 .../apache/lucene/document/FeatureField.java    | 87 +++++++++++++-------
 .../apache/lucene/document/FeatureQuery.java    | 21 ++++-
 .../lucene/document/TestFeatureField.java       | 38 +++++++--
 3 files changed, 109 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/b5bfcf06/lucene/core/src/java/org/apache/lucene/document/FeatureField.java
----------------------------------------------------------------------
diff --git a/lucene/core/src/java/org/apache/lucene/document/FeatureField.java b/lucene/core/src/java/org/apache/lucene/document/FeatureField.java
index 10a7310..2f22308 100644
--- a/lucene/core/src/java/org/apache/lucene/document/FeatureField.java
+++ b/lucene/core/src/java/org/apache/lucene/document/FeatureField.java
@@ -17,18 +17,19 @@
 package org.apache.lucene.document;
 
 import java.io.IOException;
+import java.util.Objects;
 
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.analysis.TokenStream;
 import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
 import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute;
 import org.apache.lucene.index.IndexOptions;
+import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.index.TermContext;
 import org.apache.lucene.search.BooleanQuery;
 import org.apache.lucene.search.BoostQuery;
 import org.apache.lucene.search.Explanation;
-import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.similarities.BM25Similarity;
 import org.apache.lucene.search.similarities.Similarity.SimScorer;
@@ -83,7 +84,7 @@ import org.apache.lucene.util.BytesRef;
  * <p>
  * The constants in the above formulas typically need training in order to
  * compute optimal values. If you don't know where to start, the
- * {@link #newSaturationQuery(IndexSearcher, String, String)} method uses
+ * {@link #newSaturationQuery(String, String)} method uses
  * {@code 1f} as a weight and tries to guess a sensible value for the
  * {@code pivot} parameter of the saturation function based on index
  * statistics, which shouldn't perform too badly. Here is an example, assuming
@@ -94,7 +95,7 @@ import org.apache.lucene.util.BytesRef;
  *     .add(new TermQuery(new Term("body", "apache")), Occur.SHOULD)
  *     .add(new TermQuery(new Term("body", "lucene")), Occur.SHOULD)
  *     .build();
- * Query boost = FeatureField.newSaturationQuery(searcher, "features", "pagerank");
+ * Query boost = FeatureField.newSaturationQuery("features", "pagerank");
  * Query boostedQuery = new BooleanQuery.Builder()
  *     .add(query, Occur.MUST)
  *     .add(boost, Occur.SHOULD)
@@ -211,6 +212,7 @@ public final class FeatureField extends Field {
   static abstract class FeatureFunction {
     abstract SimScorer scorer(String field, float w);
     abstract Explanation explain(String field, String feature, float w, int doc, int freq) throws IOException;
+    FeatureFunction rewrite(IndexReader reader) throws IOException { return this; }
   }
 
   static final class LogFunction extends FeatureFunction {
@@ -274,24 +276,38 @@ public final class FeatureField extends Field {
 
   static final class SaturationFunction extends FeatureFunction {
 
-    private final float pivot;
+    private final String field, feature;
+    private final Float pivot;
 
-    SaturationFunction(float pivot) {
+    SaturationFunction(String field, String feature, Float pivot) {
+      this.field = field;
+      this.feature = feature;
       this.pivot = pivot;
     }
 
     @Override
+    public FeatureFunction rewrite(IndexReader reader) throws IOException {
+      if (pivot != null) {
+        return super.rewrite(reader);
+      }
+      float newPivot = computePivotFeatureValue(reader, field, feature);
+      return new SaturationFunction(field, feature, newPivot);
+    }
+
+    @Override
     public boolean equals(Object obj) {
       if (obj == null || getClass() != obj.getClass()) {
         return false;
       }
       SaturationFunction that = (SaturationFunction) obj;
-      return pivot == that.pivot;
+      return Objects.equals(field, that.field) &&
+          Objects.equals(feature, that.feature) &&
+          Objects.equals(pivot, that.pivot);
     }
 
     @Override
     public int hashCode() {
-      return Float.hashCode(pivot);
+      return Objects.hash(field, feature, pivot);
     }
 
     @Override
@@ -301,6 +317,10 @@ public final class FeatureField extends Field {
 
     @Override
     SimScorer scorer(String field, float weight) {
+      if (pivot == null) {
+        throw new IllegalStateException("Rewrite first");
+      }
+      final float pivot = this.pivot; // unbox
       return new SimScorer() {
         @Override
         public float score(int doc, float freq) {
@@ -447,13 +467,30 @@ public final class FeatureField extends Field {
    * @throws IllegalArgumentException if weight is not in (0,64] or pivot is not in (0, +Infinity)
    */
   public static Query newSaturationQuery(String fieldName, String featureName, float weight, float pivot) {
+    return newSaturationQuery(fieldName, featureName, weight, Float.valueOf(pivot));
+  }
+
+  /**
+   * Same as {@link #newSaturationQuery(String, String, float, float)} but
+   * {@code 1f} is used as a weight and a reasonably good default pivot value
+   * is computed based on index statistics and is approximately equal to the
+   * geometric mean of all values that exist in the index.
+   * @param fieldName   field that stores features
+   * @param featureName name of the feature
+   * @throws IllegalArgumentException if weight is not in (0,64] or pivot is not in (0, +Infinity)
+   */
+  public static Query newSaturationQuery(String fieldName, String featureName) {
+    return newSaturationQuery(fieldName, featureName, 1f, null);
+  }
+
+  private static Query newSaturationQuery(String fieldName, String featureName, float weight, Float pivot) {
     if (weight <= 0 || weight > MAX_WEIGHT) {
       throw new IllegalArgumentException("weight must be in (0, " + MAX_WEIGHT + "], got: " + weight);
     }
-    if (pivot <= 0 || Float.isFinite(pivot) == false) {
+    if (pivot != null && (pivot <= 0 || Float.isFinite(pivot) == false)) {
       throw new IllegalArgumentException("pivot must be > 0, got: " + pivot);
     }
-    Query q = new FeatureQuery(fieldName, featureName, new SaturationFunction(pivot));
+    Query q = new FeatureQuery(fieldName, featureName, new SaturationFunction(fieldName, featureName, pivot));
     if (weight != 1f) {
       q = new BoostQuery(q, weight);
     }
@@ -461,25 +498,6 @@ public final class FeatureField extends Field {
   }
 
   /**
-   * Same as {@link #newSaturationQuery(String, String, float, float)} but
-   * uses {@code 1f} as a weight and tries to compute a sensible default value
-   * for {@code pivot} using
-   * {@link #computePivotFeatureValue(IndexSearcher, String, String)}. This
-   * isn't expected to give an optimal configuration of these parameters but
-   * should be a good start if you have no idea what the values of these
-   * parameters should be.
-   * @param searcher         the {@link IndexSearcher} that you will search against
-   * @param featureFieldName the field that stores features
-   * @param featureName      the name of the feature
-   */
-  public static Query newSaturationQuery(IndexSearcher searcher,
-      String featureFieldName, String featureName) throws IOException {
-    float weight = 1f;
-    float pivot = computePivotFeatureValue(searcher, featureFieldName, featureName);
-    return newSaturationQuery(featureFieldName, featureName, weight, pivot);
-  }
-
-  /**
    * Return a new {@link Query} that will score documents as
    * {@code weight * S^a / (S^a + pivot^a)} where S is the value of the static feature.
    * @param fieldName   field that stores features
@@ -514,13 +532,20 @@ public final class FeatureField extends Field {
    * representation in practice before converting it back to a float. Given that
    * floats store the exponent in the higher bits, it means that the result will
    * be an approximation of the geometric mean of all feature values.
-   * @param searcher     the {@link IndexSearcher} to search against
+   * @param reader       the {@link IndexReader} to search against
    * @param featureField the field that stores features
    * @param featureName  the name of the feature
    */
-  public static float computePivotFeatureValue(IndexSearcher searcher, String featureField, String featureName) throws IOException {
+  static float computePivotFeatureValue(IndexReader reader, String featureField, String featureName) throws IOException {
     Term term = new Term(featureField, featureName);
-    TermContext context = TermContext.build(searcher.getIndexReader().getContext(), term);
+    TermContext context = TermContext.build(reader.getContext(), term);
+    if (context.docFreq() == 0) {
+      // avoid division by 0
+      // The return value doesn't matter much here, the term doesn't exist,
+      // it will never be used for scoring. Just make sure to return a legal
+      // value.
+      return 1;
+    }
     float avgFreq = (float) ((double) context.totalTermFreq() / context.docFreq());
     return decodeFeatureValue(avgFreq);
   }

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/b5bfcf06/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java
----------------------------------------------------------------------
diff --git a/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java b/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java
index eb71d05..fd9bc45 100644
--- a/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java
@@ -21,6 +21,7 @@ import java.util.Objects;
 import java.util.Set;
 
 import org.apache.lucene.document.FeatureField.FeatureFunction;
+import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.PostingsEnum;
 import org.apache.lucene.index.Term;
@@ -48,6 +49,15 @@ final class FeatureQuery extends Query {
   }
 
   @Override
+  public Query rewrite(IndexReader reader) throws IOException {
+    FeatureFunction rewritten = function.rewrite(reader);
+    if (function != rewritten) {
+      return new FeatureQuery(fieldName, featureName, rewritten);
+    }
+    return super.rewrite(reader);
+  }
+
+  @Override
   public boolean equals(Object obj) {
     if (obj == null || getClass() != obj.getClass()) {
       return false;
@@ -77,7 +87,16 @@ final class FeatureQuery extends Query {
       }
 
       @Override
-      public void extractTerms(Set<Term> terms) {}
+      public void extractTerms(Set<Term> terms) {
+        if (needsScores == false) {
+          // features are irrelevant to highlighting, skip
+        } else {
+          // extracting the term here will help get better scoring with
+          // distributed term statistics if the saturation function is used
+          // and the pivot value is computed automatically
+          terms.add(new Term(fieldName, featureName));
+        }
+      }
 
       @Override
       public Explanation explain(LeafReaderContext context, int doc) throws IOException {

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/b5bfcf06/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java
----------------------------------------------------------------------
diff --git a/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java b/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java
index c15c226..88f5ede 100644
--- a/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java
+++ b/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java
@@ -17,10 +17,15 @@
 package org.apache.lucene.document;
 
 import java.io.IOException;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
 
 import org.apache.lucene.document.Field.Store;
 import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.MultiReader;
 import org.apache.lucene.index.RandomIndexWriter;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.BooleanClause.Occur;
@@ -209,7 +214,7 @@ public class TestFeatureField extends LuceneTestCase {
   }
 
   public void testSatuSimScorer() throws IOException {
-    doTestSimScorer(new FeatureField.SaturationFunction(20f).scorer("foo", 3f));
+    doTestSimScorer(new FeatureField.SaturationFunction("foo", "bar", 20f).scorer("foo", 3f));
   }
 
   public void testSigmSimScorer() throws IOException {
@@ -229,6 +234,14 @@ public class TestFeatureField extends LuceneTestCase {
   public void testComputePivotFeatureValue() throws IOException {
     Directory dir = newDirectory();
     RandomIndexWriter writer = new RandomIndexWriter(random(), dir, newIndexWriterConfig());
+
+    // Make sure that we create a legal pivot on missing features
+    DirectoryReader reader = writer.getReader();
+    float pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank");
+    assertTrue(Float.isFinite(pivot));
+    assertTrue(pivot > 0);
+    reader.close();
+
     Document doc = new Document();
     FeatureField pagerank = new FeatureField("features", "pagerank", 1);
     doc.add(pagerank);
@@ -247,11 +260,10 @@ public class TestFeatureField extends LuceneTestCase {
     pagerank.setFeatureValue(42);
     writer.addDocument(doc);
 
-    DirectoryReader reader = writer.getReader();
+    reader = writer.getReader();
     writer.close();
 
-    IndexSearcher searcher = new IndexSearcher(reader);
-    float pivot = FeatureField.computePivotFeatureValue(searcher, "features", "pagerank");
+    pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank");
     double expected = Math.pow(10 * 100 * 1 * 42, 1/4.); // geometric mean
     assertEquals(expected, pivot, 0.1);
 
@@ -259,6 +271,22 @@ public class TestFeatureField extends LuceneTestCase {
     dir.close();
   }
 
+  public void testExtractTerms() throws IOException {
+    IndexReader reader = new MultiReader();
+    IndexSearcher searcher = newSearcher(reader);
+    Query query = FeatureField.newLogQuery("field", "term", 2f, 42);
+
+    Weight weight = searcher.createWeight(query, false, 1f);
+    Set<Term> terms = new HashSet<>();
+    weight.extractTerms(terms);
+    assertEquals(Collections.emptySet(), terms);
+
+    terms = new HashSet<>();
+    weight = searcher.createWeight(query, true, 1f);
+    weight.extractTerms(terms);
+    assertEquals(Collections.singleton(new Term("field", "term")), terms);
+  }
+
   public void testDemo() throws IOException {
     Directory dir = newDirectory();
     RandomIndexWriter writer = new RandomIndexWriter(random(), dir, newIndexWriterConfig()
@@ -297,7 +325,7 @@ public class TestFeatureField extends LuceneTestCase {
         .add(new TermQuery(new Term("body", "apache")), Occur.SHOULD)
         .add(new TermQuery(new Term("body", "lucene")), Occur.SHOULD)
         .build();
-    Query boost = FeatureField.newSaturationQuery(searcher, "features", "pagerank");
+    Query boost = FeatureField.newSaturationQuery("features", "pagerank");
     Query boostedQuery = new BooleanQuery.Builder()
         .add(query, Occur.MUST)
         .add(boost, Occur.SHOULD)