You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by ju...@apache.org on 2021/02/04 20:43:04 UTC

[lucene-solr] branch master updated: LUCENE-9725: Allow BM25FQuery to use other similarities. (#2293)

This is an automated email from the ASF dual-hosted git repository.

julietibs pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/lucene-solr.git


The following commit(s) were added to refs/heads/master by this push:
     new c3f5454  LUCENE-9725: Allow BM25FQuery to use other similarities. (#2293)
c3f5454 is described below

commit c3f5454d4903897211021eed0c824cb3797d85d9
Author: Julie Tibshirani <ju...@elastic.co>
AuthorDate: Thu Feb 4 12:42:45 2021 -0800

    LUCENE-9725: Allow BM25FQuery to use other similarities. (#2293)
    
    From a high level, BM25FQuery (1) computes statistics that represent the combined
    field content and (2) passes these to a score function. This model makes sense
    for many similarities besides BM25.
    
    This PR unhardcodes BM25Similarity in BM25FQuery and instead uses the one
    configured on IndexSearcher. It also renames BM25FQuery since it's no longer
    specific to BM25.
---
 .../{BM25FQuery.java => CombinedFieldQuery.java}   |  91 ++++++-------
 .../sandbox/search/MultiNormsLeafSimScorer.java    |   2 +-
 ...BM25FQuery.java => TestCombinedFieldQuery.java} | 145 ++++++++++++++++-----
 3 files changed, 154 insertions(+), 84 deletions(-)

diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/BM25FQuery.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/CombinedFieldQuery.java
similarity index 85%
rename from lucene/sandbox/src/java/org/apache/lucene/sandbox/search/BM25FQuery.java
rename to lucene/sandbox/src/java/org/apache/lucene/sandbox/search/CombinedFieldQuery.java
index 827dc85..1c79657 100644
--- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/BM25FQuery.java
+++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/CombinedFieldQuery.java
@@ -54,49 +54,49 @@ import org.apache.lucene.search.TermScorer;
 import org.apache.lucene.search.TermStatistics;
 import org.apache.lucene.search.Weight;
 import org.apache.lucene.search.similarities.BM25Similarity;
+import org.apache.lucene.search.similarities.DFRSimilarity;
 import org.apache.lucene.search.similarities.Similarity;
 import org.apache.lucene.search.similarities.SimilarityBase;
 import org.apache.lucene.util.Accountable;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.RamUsageEstimator;
+import org.apache.lucene.util.SmallFloat;
 
 /**
  * A {@link Query} that treats multiple fields as a single stream and scores terms as if you had
  * indexed them as a single term in a single field.
  *
- * <p>For scoring purposes this query implements the BM25F's simple formula described in:
- * http://www.staff.city.ac.uk/~sb317/papers/foundations_bm25_review.pdf
+ * <p>The query works as follows:
  *
- * <p>The per-field similarity is ignored but to be compatible each field must use a {@link
- * Similarity} at index time that encodes norms the same way as {@link SimilarityBase#computeNorm}.
+ * <ol>
+ *   <li>Given a list of fields and weights, it pretends there is a synthetic combined field where
+ *       all terms have been indexed. It computes new term and collection statistics for this
+ *       combined field.
+ *   <li>It uses a disjunction iterator and {@link IndexSearcher#getSimilarity} to score documents.
+ * </ol>
+ *
+ * <p>In order for a similarity to be compatible, {@link Similarity#computeNorm} must be additive:
+ * the norm of the combined field is the sum of norms for each individual field. The norms must also
+ * be encoded using {@link SmallFloat#intToByte4}. These requirements hold for all similarities that
+ * compute norms the same way as {@link SimilarityBase#computeNorm}, which includes {@link
+ * BM25Similarity} and {@link DFRSimilarity}. Per-field similarities are not supported.
+ *
+ * <p>The scoring is based on BM25F's simple formula described in:
+ * http://www.staff.city.ac.uk/~sb317/papers/foundations_bm25_review.pdf. This query implements the
+ * same approach but allows other similarities besides {@link
+ * org.apache.lucene.search.similarities.BM25Similarity}.
  *
  * @lucene.experimental
  */
-public final class BM25FQuery extends Query implements Accountable {
+public final class CombinedFieldQuery extends Query implements Accountable {
   private static final long BASE_RAM_BYTES =
-      RamUsageEstimator.shallowSizeOfInstance(BM25FQuery.class);
+      RamUsageEstimator.shallowSizeOfInstance(CombinedFieldQuery.class);
 
-  /** A builder for {@link BM25FQuery}. */
+  /** A builder for {@link CombinedFieldQuery}. */
   public static class Builder {
-    private final BM25Similarity similarity;
     private final Map<String, FieldAndWeight> fieldAndWeights = new HashMap<>();
     private final Set<BytesRef> termsSet = new HashSet<>();
 
-    /** Default builder. */
-    public Builder() {
-      this.similarity = new BM25Similarity();
-    }
-
-    /**
-     * Builder with the supplied parameter values.
-     *
-     * @param k1 Controls non-linear term frequency normalization (saturation).
-     * @param b Controls to what degree document length normalizes tf values.
-     */
-    public Builder(float k1, float b) {
-      this.similarity = new BM25Similarity(k1, b);
-    }
-
     /**
      * Adds a field to this builder.
      *
@@ -129,14 +129,14 @@ public final class BM25FQuery extends Query implements Accountable {
       return this;
     }
 
-    /** Builds the {@link BM25FQuery}. */
-    public BM25FQuery build() {
+    /** Builds the {@link CombinedFieldQuery}. */
+    public CombinedFieldQuery build() {
       int size = fieldAndWeights.size() * termsSet.size();
       if (size > IndexSearcher.getMaxClauseCount()) {
         throw new IndexSearcher.TooManyClauses();
       }
       BytesRef[] terms = termsSet.toArray(new BytesRef[0]);
-      return new BM25FQuery(similarity, new TreeMap<>(fieldAndWeights), terms);
+      return new CombinedFieldQuery(new TreeMap<>(fieldAndWeights), terms);
     }
   }
 
@@ -150,8 +150,6 @@ public final class BM25FQuery extends Query implements Accountable {
     }
   }
 
-  // the similarity to use for scoring.
-  private final BM25Similarity similarity;
   // sorted map for fields.
   private final TreeMap<String, FieldAndWeight> fieldAndWeights;
   // array of terms, sorted.
@@ -161,11 +159,7 @@ public final class BM25FQuery extends Query implements Accountable {
 
   private final long ramBytesUsed;
 
-  private BM25FQuery(
-      BM25Similarity similarity,
-      TreeMap<String, FieldAndWeight> fieldAndWeights,
-      BytesRef[] terms) {
-    this.similarity = similarity;
+  private CombinedFieldQuery(TreeMap<String, FieldAndWeight> fieldAndWeights, BytesRef[] terms) {
     this.fieldAndWeights = fieldAndWeights;
     this.terms = terms;
     int numFieldTerms = fieldAndWeights.size() * terms.length;
@@ -194,7 +188,7 @@ public final class BM25FQuery extends Query implements Accountable {
 
   @Override
   public String toString(String field) {
-    StringBuilder builder = new StringBuilder("BM25F((");
+    StringBuilder builder = new StringBuilder("CombinedFieldQuery((");
     int pos = 0;
     for (FieldAndWeight fieldWeight : fieldAndWeights.values()) {
       if (pos++ != 0) {
@@ -225,7 +219,7 @@ public final class BM25FQuery extends Query implements Accountable {
 
   @Override
   public boolean equals(Object other) {
-    return sameClassAs(other) && Arrays.equals(terms, ((BM25FQuery) other).terms);
+    return sameClassAs(other) && Arrays.equals(terms, ((CombinedFieldQuery) other).terms);
   }
 
   @Override
@@ -277,7 +271,7 @@ public final class BM25FQuery extends Query implements Accountable {
   public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
       throws IOException {
     if (scoreMode.needsScores()) {
-      return new BM25FWeight(this, searcher, scoreMode, boost);
+      return new CombinedFieldWeight(this, searcher, scoreMode, boost);
     } else {
       // rewrite to a simple disjunction if the score is not needed.
       Query bq = rewriteToBoolean();
@@ -285,12 +279,12 @@ public final class BM25FQuery extends Query implements Accountable {
     }
   }
 
-  class BM25FWeight extends Weight {
+  class CombinedFieldWeight extends Weight {
     private final IndexSearcher searcher;
     private final TermStates termStates[];
     private final Similarity.SimScorer simWeight;
 
-    BM25FWeight(Query query, IndexSearcher searcher, ScoreMode scoreMode, float boost)
+    CombinedFieldWeight(Query query, IndexSearcher searcher, ScoreMode scoreMode, float boost)
         throws IOException {
       super(query);
       assert scoreMode.needsScores();
@@ -313,7 +307,8 @@ public final class BM25FQuery extends Query implements Accountable {
         CollectionStatistics pseudoCollectionStats = mergeCollectionStatistics(searcher);
         TermStatistics pseudoTermStatistics =
             new TermStatistics(new BytesRef("pseudo_term"), docFreq, Math.max(1, totalTermFreq));
-        this.simWeight = similarity.scorer(boost, pseudoCollectionStats, pseudoTermStatistics);
+        this.simWeight =
+            searcher.getSimilarity().scorer(boost, pseudoCollectionStats, pseudoTermStatistics);
       } else {
         this.simWeight = null;
       }
@@ -352,8 +347,8 @@ public final class BM25FQuery extends Query implements Accountable {
         int newDoc = scorer.iterator().advance(doc);
         if (newDoc == doc) {
           final float freq;
-          if (scorer instanceof BM25FScorer) {
-            freq = ((BM25FScorer) scorer).freq();
+          if (scorer instanceof CombinedFieldScorer) {
+            freq = ((CombinedFieldScorer) scorer).freq();
           } else {
             assert scorer instanceof TermScorer;
             freq = ((TermScorer) scorer).freq();
@@ -365,13 +360,7 @@ public final class BM25FQuery extends Query implements Accountable {
           Explanation scoreExplanation = docScorer.explain(doc, freqExplanation);
           return Explanation.match(
               scoreExplanation.getValue(),
-              "weight("
-                  + getQuery()
-                  + " in "
-                  + doc
-                  + ") ["
-                  + similarity.getClass().getSimpleName()
-                  + "], result of:",
+              "weight(" + getQuery() + " in " + doc + "), result of:",
               scoreExplanation);
         }
       }
@@ -418,7 +407,7 @@ public final class BM25FQuery extends Query implements Accountable {
       // Even though it is called approximation, it is accurate since none of
       // the sub iterators are two-phase iterators.
       DocIdSetIterator iterator = new DisjunctionDISIApproximation(queue);
-      return new BM25FScorer(this, queue, iterator, scoringSimScorer);
+      return new CombinedFieldScorer(this, queue, iterator, scoringSimScorer);
     }
 
     @Override
@@ -440,12 +429,12 @@ public final class BM25FQuery extends Query implements Accountable {
     }
   }
 
-  private static class BM25FScorer extends Scorer {
+  private static class CombinedFieldScorer extends Scorer {
     private final DisiPriorityQueue queue;
     private final DocIdSetIterator iterator;
     private final MultiNormsLeafSimScorer simScorer;
 
-    BM25FScorer(
+    CombinedFieldScorer(
         Weight weight,
         DisiPriorityQueue queue,
         DocIdSetIterator iterator,
diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/MultiNormsLeafSimScorer.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/MultiNormsLeafSimScorer.java
index 9fa8f19..e9fde6f 100644
--- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/MultiNormsLeafSimScorer.java
+++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/MultiNormsLeafSimScorer.java
@@ -16,7 +16,7 @@
  */
 package org.apache.lucene.sandbox.search;
 
-import static org.apache.lucene.sandbox.search.BM25FQuery.FieldAndWeight;
+import static org.apache.lucene.sandbox.search.CombinedFieldQuery.FieldAndWeight;
 
 import java.io.IOException;
 import java.util.ArrayList;
diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestBM25FQuery.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestCombinedFieldQuery.java
similarity index 64%
rename from lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestBM25FQuery.java
rename to lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestCombinedFieldQuery.java
index 0f96fcd..f3c0044 100644
--- a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestBM25FQuery.java
+++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestCombinedFieldQuery.java
@@ -16,8 +16,9 @@
  */
 package org.apache.lucene.sandbox.search;
 
+import com.carrotsearch.randomizedtesting.generators.RandomPicks;
 import java.io.IOException;
-import org.apache.lucene.analysis.MockAnalyzer;
+import java.util.Arrays;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field.Store;
 import org.apache.lucene.document.StringField;
@@ -40,21 +41,25 @@ import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TopScoreDocCollector;
 import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.search.similarities.BM25Similarity;
+import org.apache.lucene.search.similarities.BooleanSimilarity;
+import org.apache.lucene.search.similarities.ClassicSimilarity;
+import org.apache.lucene.search.similarities.LMDirichletSimilarity;
+import org.apache.lucene.search.similarities.LMJelinekMercerSimilarity;
 import org.apache.lucene.search.similarities.Similarity;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.LuceneTestCase;
 
-public class TestBM25FQuery extends LuceneTestCase {
+public class TestCombinedFieldQuery extends LuceneTestCase {
   public void testInvalid() {
-    BM25FQuery.Builder builder = new BM25FQuery.Builder();
+    CombinedFieldQuery.Builder builder = new CombinedFieldQuery.Builder();
     IllegalArgumentException exc =
         expectThrows(IllegalArgumentException.class, () -> builder.addField("foo", 0.5f));
     assertEquals(exc.getMessage(), "weight must be greater or equal to 1");
   }
 
   public void testRewrite() throws IOException {
-    BM25FQuery.Builder builder = new BM25FQuery.Builder();
+    CombinedFieldQuery.Builder builder = new CombinedFieldQuery.Builder();
     IndexReader reader = new MultiReader();
     IndexSearcher searcher = new IndexSearcher(reader);
     Query actual = searcher.rewrite(builder.build());
@@ -80,21 +85,25 @@ public class TestBM25FQuery extends LuceneTestCase {
   }
 
   public void testToString() {
-    assertEquals("BM25F(()())", new BM25FQuery.Builder().build().toString());
-    BM25FQuery.Builder builder = new BM25FQuery.Builder();
+    assertEquals("CombinedFieldQuery(()())", new CombinedFieldQuery.Builder().build().toString());
+    CombinedFieldQuery.Builder builder = new CombinedFieldQuery.Builder();
     builder.addField("foo", 1f);
-    assertEquals("BM25F((foo)())", builder.build().toString());
+    assertEquals("CombinedFieldQuery((foo)())", builder.build().toString());
     builder.addTerm(new BytesRef("bar"));
-    assertEquals("BM25F((foo)(bar))", builder.build().toString());
+    assertEquals("CombinedFieldQuery((foo)(bar))", builder.build().toString());
     builder.addField("title", 3f);
-    assertEquals("BM25F((foo title^3.0)(bar))", builder.build().toString());
+    assertEquals("CombinedFieldQuery((foo title^3.0)(bar))", builder.build().toString());
     builder.addTerm(new BytesRef("baz"));
-    assertEquals("BM25F((foo title^3.0)(bar baz))", builder.build().toString());
+    assertEquals("CombinedFieldQuery((foo title^3.0)(bar baz))", builder.build().toString());
   }
 
   public void testSameScore() throws IOException {
     Directory dir = newDirectory();
-    RandomIndexWriter w = new RandomIndexWriter(random(), dir);
+    Similarity similarity = randomCompatibleSimilarity();
+
+    IndexWriterConfig iwc = new IndexWriterConfig();
+    iwc.setSimilarity(similarity);
+    RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc);
 
     Document doc = new Document();
     doc.add(new StringField("f", "a", Store.NO));
@@ -108,8 +117,9 @@ public class TestBM25FQuery extends LuceneTestCase {
 
     IndexReader reader = w.getReader();
     IndexSearcher searcher = newSearcher(reader);
-    BM25FQuery query =
-        new BM25FQuery.Builder()
+    searcher.setSimilarity(similarity);
+    CombinedFieldQuery query =
+        new CombinedFieldQuery.Builder()
             .addField("f", 1f)
             .addField("g", 1f)
             .addTerm(new BytesRef("a"))
@@ -130,9 +140,14 @@ public class TestBM25FQuery extends LuceneTestCase {
     dir.close();
   }
 
-  public void testAgainstCopyField() throws IOException {
+  public void testCopyField() throws IOException {
     Directory dir = newDirectory();
-    RandomIndexWriter w = new RandomIndexWriter(random(), dir, new MockAnalyzer(random()));
+    Similarity similarity = randomCompatibleSimilarity();
+
+    IndexWriterConfig iwc = new IndexWriterConfig();
+    iwc.setSimilarity(similarity);
+    RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc);
+
     int numMatch = atLeast(10);
     int boost1 = Math.max(1, random().nextInt(5));
     int boost2 = Math.max(1, random().nextInt(5));
@@ -163,31 +178,95 @@ public class TestBM25FQuery extends LuceneTestCase {
     }
     IndexReader reader = w.getReader();
     IndexSearcher searcher = newSearcher(reader);
-    searcher.setSimilarity(new BM25Similarity());
-    BM25FQuery query =
-        new BM25FQuery.Builder()
+
+    searcher.setSimilarity(similarity);
+    CombinedFieldQuery query =
+        new CombinedFieldQuery.Builder()
             .addField("a", (float) boost1)
             .addField("b", (float) boost2)
             .addTerm(new BytesRef("foo"))
+            .build();
+
+    checkExpectedHits(searcher, numMatch, query, new TermQuery(new Term("ab", "foo")));
+
+    reader.close();
+    w.close();
+    dir.close();
+  }
+
+  public void testCopyFieldWithMultipleTerms() throws IOException {
+    Directory dir = newDirectory();
+    Similarity similarity = randomCompatibleSimilarity();
+
+    IndexWriterConfig iwc = new IndexWriterConfig();
+    iwc.setSimilarity(similarity);
+    RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc);
+
+    int numMatch = atLeast(10);
+    int boost1 = Math.max(1, random().nextInt(5));
+    int boost2 = Math.max(1, random().nextInt(5));
+    for (int i = 0; i < numMatch; i++) {
+      Document doc = new Document();
+
+      int freqA = random().nextInt(5) + 1;
+      for (int j = 0; j < freqA; j++) {
+        doc.add(new TextField("a", "foo", Store.NO));
+      }
+      int freqB = random().nextInt(5) + 1;
+      for (int j = 0; j < freqB; j++) {
+        doc.add(new TextField("b", "bar", Store.NO));
+      }
+      int freqAB = freqA * boost1 + freqB * boost2;
+      for (int j = 0; j < freqAB; j++) {
+        doc.add(new TextField("ab", "foo", Store.NO));
+      }
+      w.addDocument(doc);
+    }
+    IndexReader reader = w.getReader();
+    IndexSearcher searcher = newSearcher(reader);
+
+    searcher.setSimilarity(similarity);
+    CombinedFieldQuery query =
+        new CombinedFieldQuery.Builder()
+            .addField("a", (float) boost1)
+            .addField("b", (float) boost2)
             .addTerm(new BytesRef("foo"))
+            .addTerm(new BytesRef("bar"))
             .build();
 
-    TopScoreDocCollector bm25FCollector =
-        TopScoreDocCollector.create(numMatch, null, Integer.MAX_VALUE);
-    searcher.search(query, bm25FCollector);
-    TopDocs bm25FTopDocs = bm25FCollector.topDocs();
-    assertEquals(numMatch, bm25FTopDocs.totalHits.value);
-    TopScoreDocCollector collector =
-        TopScoreDocCollector.create(reader.numDocs(), null, Integer.MAX_VALUE);
-    searcher.search(new TermQuery(new Term("ab", "foo")), collector);
-    TopDocs topDocs = collector.topDocs();
-    CheckHits.checkEqual(query, topDocs.scoreDocs, bm25FTopDocs.scoreDocs);
+    checkExpectedHits(searcher, numMatch, query, new TermQuery(new Term("ab", "foo")));
 
     reader.close();
     w.close();
     dir.close();
   }
 
+  private static Similarity randomCompatibleSimilarity() {
+    return RandomPicks.randomFrom(
+        random(),
+        Arrays.asList(
+            new BM25Similarity(),
+            new BooleanSimilarity(),
+            new ClassicSimilarity(),
+            new LMDirichletSimilarity(),
+            new LMJelinekMercerSimilarity(0.1f)));
+  }
+
+  private void checkExpectedHits(
+      IndexSearcher searcher, int numHits, Query firstQuery, Query secondQuery) throws IOException {
+    TopScoreDocCollector firstCollector =
+        TopScoreDocCollector.create(numHits, null, Integer.MAX_VALUE);
+    searcher.search(firstQuery, firstCollector);
+    TopDocs firstTopDocs = firstCollector.topDocs();
+    assertEquals(numHits, firstTopDocs.totalHits.value);
+
+    TopScoreDocCollector secondCollector =
+        TopScoreDocCollector.create(numHits, null, Integer.MAX_VALUE);
+    searcher.search(secondQuery, secondCollector);
+    TopDocs secondTopDocs = secondCollector.topDocs();
+    CheckHits.checkEqual(firstQuery, secondTopDocs.scoreDocs, firstTopDocs.scoreDocs);
+  }
+
   public void testDocWithNegativeNorms() throws IOException {
     Directory dir = newDirectory();
     IndexWriterConfig iwc = new IndexWriterConfig();
@@ -204,8 +283,9 @@ public class TestBM25FQuery extends LuceneTestCase {
 
     IndexReader reader = w.getReader();
     IndexSearcher searcher = newSearcher(reader);
-    BM25FQuery query =
-        new BM25FQuery.Builder()
+    searcher.setSimilarity(new BM25Similarity());
+    CombinedFieldQuery query =
+        new CombinedFieldQuery.Builder()
             .addField("f")
             .addField("g")
             .addTerm(new BytesRef(queryString))
@@ -239,8 +319,9 @@ public class TestBM25FQuery extends LuceneTestCase {
 
     IndexReader reader = w.getReader();
     IndexSearcher searcher = newSearcher(reader);
-    BM25FQuery query =
-        new BM25FQuery.Builder()
+    searcher.setSimilarity(new BM25Similarity());
+    CombinedFieldQuery query =
+        new CombinedFieldQuery.Builder()
             .addField("f")
             .addField("g")
             .addTerm(new BytesRef(queryString))