You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by jp...@apache.org on 2021/01/11 14:03:51 UTC

[lucene-solr] branch master updated: LUCENE-9346: Support minimumNumberShouldMatch in WANDScorer (#2141)

This is an automated email from the ASF dual-hosted git repository.

jpountz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/lucene-solr.git


The following commit(s) were added to refs/heads/master by this push:
     new c249328  LUCENE-9346: Support minimumNumberShouldMatch in WANDScorer (#2141)
c249328 is described below

commit c2493283a58ea19a13887a732328c1eaf970d371
Author: zacharymorn <za...@yahoo.com>
AuthorDate: Mon Jan 11 06:03:29 2021 -0800

    LUCENE-9346: Support minimumNumberShouldMatch in WANDScorer (#2141)
    
    Co-authored-by: Adrien Grand <jp...@gmail.com>
---
 .../lucene/search/Boolean2ScorerSupplier.java      |   9 +-
 .../lucene/search/MinShouldMatchSumScorer.java     |  30 +--
 .../java/org/apache/lucene/search/ScorerUtil.java  |  49 +++++
 .../java/org/apache/lucene/search/WANDScorer.java  |  40 +++-
 .../org/apache/lucene/search/TestWANDScorer.java   | 225 +++++++++++++++++++++
 5 files changed, 309 insertions(+), 44 deletions(-)

diff --git a/lucene/core/src/java/org/apache/lucene/search/Boolean2ScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/search/Boolean2ScorerSupplier.java
index 3fa5886..d0a4dbd 100644
--- a/lucene/core/src/java/org/apache/lucene/search/Boolean2ScorerSupplier.java
+++ b/lucene/core/src/java/org/apache/lucene/search/Boolean2ScorerSupplier.java
@@ -75,7 +75,7 @@ final class Boolean2ScorerSupplier extends ScorerSupplier {
     } else {
       final Collection<ScorerSupplier> optionalScorers = subs.get(Occur.SHOULD);
       final long shouldCost =
-          MinShouldMatchSumScorer.cost(
+          ScorerUtil.costWithMinShouldMatch(
               optionalScorers.stream().mapToLong(ScorerSupplier::cost),
               optionalScorers.size(),
               minShouldMatch);
@@ -230,10 +230,11 @@ final class Boolean2ScorerSupplier extends ScorerSupplier {
       for (ScorerSupplier scorer : optional) {
         optionalScorers.add(scorer.get(leadCost));
       }
-      if (minShouldMatch > 1) {
+
+      if (scoreMode == ScoreMode.TOP_SCORES) {
+        return new WANDScorer(weight, optionalScorers, minShouldMatch);
+      } else if (minShouldMatch > 1) {
         return new MinShouldMatchSumScorer(weight, optionalScorers, minShouldMatch);
-      } else if (scoreMode == ScoreMode.TOP_SCORES) {
-        return new WANDScorer(weight, optionalScorers);
       } else {
         return new DisjunctionSumScorer(weight, optionalScorers, scoreMode);
       }
diff --git a/lucene/core/src/java/org/apache/lucene/search/MinShouldMatchSumScorer.java b/lucene/core/src/java/org/apache/lucene/search/MinShouldMatchSumScorer.java
index bdcdca9..574fd1a 100644
--- a/lucene/core/src/java/org/apache/lucene/search/MinShouldMatchSumScorer.java
+++ b/lucene/core/src/java/org/apache/lucene/search/MinShouldMatchSumScorer.java
@@ -24,9 +24,6 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
-import java.util.stream.LongStream;
-import java.util.stream.StreamSupport;
-import org.apache.lucene.util.PriorityQueue;
 
 /**
  * A {@link Scorer} for {@link BooleanQuery} when {@link
@@ -44,31 +41,6 @@ import org.apache.lucene.util.PriorityQueue;
  */
 final class MinShouldMatchSumScorer extends Scorer {
 
-  static long cost(LongStream costs, int numScorers, int minShouldMatch) {
-    // the idea here is the following: a boolean query c1,c2,...cn with minShouldMatch=m
-    // could be rewritten to:
-    // (c1 AND (c2..cn|msm=m-1)) OR (!c1 AND (c2..cn|msm=m))
-    // if we assume that clauses come in ascending cost, then
-    // the cost of the first part is the cost of c1 (because the cost of a conjunction is
-    // the cost of the least costly clause)
-    // the cost of the second part is the cost of finding m matches among the c2...cn
-    // remaining clauses
-    // since it is a disjunction overall, the total cost is the sum of the costs of these
-    // two parts
-
-    // If we recurse infinitely, we find out that the cost of a msm query is the sum of the
-    // costs of the num_scorers - minShouldMatch + 1 least costly scorers
-    final PriorityQueue<Long> pq =
-        new PriorityQueue<Long>(numScorers - minShouldMatch + 1) {
-          @Override
-          protected boolean lessThan(Long a, Long b) {
-            return a > b;
-          }
-        };
-    costs.forEach(pq::insertWithOverflow);
-    return StreamSupport.stream(pq.spliterator(), false).mapToLong(Number::longValue).sum();
-  }
-
   final int minShouldMatch;
 
   // list of scorers which 'lead' the iteration and are currently
@@ -111,7 +83,7 @@ final class MinShouldMatchSumScorer extends Scorer {
     }
 
     this.cost =
-        cost(
+        ScorerUtil.costWithMinShouldMatch(
             scorers.stream().map(Scorer::iterator).mapToLong(DocIdSetIterator::cost),
             scorers.size(),
             minShouldMatch);
diff --git a/lucene/core/src/java/org/apache/lucene/search/ScorerUtil.java b/lucene/core/src/java/org/apache/lucene/search/ScorerUtil.java
new file mode 100644
index 0000000..50c9607
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/search/ScorerUtil.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import java.util.stream.LongStream;
+import java.util.stream.StreamSupport;
+import org.apache.lucene.util.PriorityQueue;
+
+/** Util class for Scorer related methods */
+class ScorerUtil {
+  static long costWithMinShouldMatch(LongStream costs, int numScorers, int minShouldMatch) {
+    // the idea here is the following: a boolean query c1,c2,...cn with minShouldMatch=m
+    // could be rewritten to:
+    // (c1 AND (c2..cn|msm=m-1)) OR (!c1 AND (c2..cn|msm=m))
+    // if we assume that clauses come in ascending cost, then
+    // the cost of the first part is the cost of c1 (because the cost of a conjunction is
+    // the cost of the least costly clause)
+    // the cost of the second part is the cost of finding m matches among the c2...cn
+    // remaining clauses
+    // since it is a disjunction overall, the total cost is the sum of the costs of these
+    // two parts
+
+    // If we recurse infinitely, we find out that the cost of a msm query is the sum of the
+    // costs of the num_scorers - minShouldMatch + 1 least costly scorers
+    final PriorityQueue<Long> pq =
+        new PriorityQueue<Long>(numScorers - minShouldMatch + 1) {
+          @Override
+          protected boolean lessThan(Long a, Long b) {
+            return a > b;
+          }
+        };
+    costs.forEach(pq::insertWithOverflow);
+    return StreamSupport.stream(pq.spliterator(), false).mapToLong(Number::longValue).sum();
+  }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/search/WANDScorer.java b/lucene/core/src/java/org/apache/lucene/search/WANDScorer.java
index 2c94159..b1ed3bf 100644
--- a/lucene/core/src/java/org/apache/lucene/search/WANDScorer.java
+++ b/lucene/core/src/java/org/apache/lucene/search/WANDScorer.java
@@ -19,6 +19,7 @@ package org.apache.lucene.search;
 import static org.apache.lucene.search.DisiPriorityQueue.leftNode;
 import static org.apache.lucene.search.DisiPriorityQueue.parentNode;
 import static org.apache.lucene.search.DisiPriorityQueue.rightNode;
+import static org.apache.lucene.search.ScorerUtil.costWithMinShouldMatch;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -130,10 +131,21 @@ final class WANDScorer extends Scorer {
 
   int upTo; // upper bound for which max scores are valid
 
-  WANDScorer(Weight weight, Collection<Scorer> scorers) throws IOException {
+  final int minShouldMatch;
+  int freq;
+
+  WANDScorer(Weight weight, Collection<Scorer> scorers, int minShouldMatch) throws IOException {
     super(weight);
 
+    if (minShouldMatch >= scorers.size()) {
+      throw new IllegalArgumentException("minShouldMatch should be < the number of scorers");
+    }
+
     this.minCompetitiveScore = 0;
+
+    assert minShouldMatch >= 0 : "minShouldMatch should not be negative, but got " + minShouldMatch;
+    this.minShouldMatch = minShouldMatch;
+
     this.doc = -1;
     this.upTo = -1; // will be computed on the first call to nextDoc/advance
 
@@ -155,13 +167,15 @@ final class WANDScorer extends Scorer {
     // Use a scaling factor of 0 if all max scores are either 0 or +Infty
     this.scalingFactor = scalingFactor.orElse(0);
 
-    long cost = 0;
     for (Scorer scorer : scorers) {
-      DisiWrapper w = new DisiWrapper(scorer);
-      cost += w.cost;
-      addLead(w);
+      addLead(new DisiWrapper(scorer));
     }
-    this.cost = cost;
+
+    this.cost =
+        costWithMinShouldMatch(
+            scorers.stream().map(Scorer::iterator).mapToLong(DocIdSetIterator::cost),
+            scorers.size(),
+            minShouldMatch);
     this.maxScorePropagator = new MaxScoreSumPropagator(scorers);
   }
 
@@ -265,15 +279,17 @@ final class WANDScorer extends Scorer {
 
       @Override
       public boolean matches() throws IOException {
-        while (leadMaxScore < minCompetitiveScore) {
-          if (leadMaxScore + tailMaxScore >= minCompetitiveScore) {
+        while (leadMaxScore < minCompetitiveScore || freq < minShouldMatch) {
+          if (leadMaxScore + tailMaxScore < minCompetitiveScore
+              || freq + tailSize < minShouldMatch) {
+            return false;
+          } else {
             // a match on doc is still possible, try to
             // advance scorers from the tail
             advanceTail();
-          } else {
-            return false;
           }
         }
+
         return true;
       }
 
@@ -290,6 +306,7 @@ final class WANDScorer extends Scorer {
     lead.next = this.lead;
     this.lead = lead;
     leadMaxScore += lead.maxScore;
+    freq += 1;
   }
 
   /** Move disis that are in 'lead' back to the tail. */
@@ -429,6 +446,7 @@ final class WANDScorer extends Scorer {
     lead = head.pop();
     lead.next = null;
     leadMaxScore = lead.maxScore;
+    freq = 1;
     doc = lead.doc;
     while (head.size() > 0 && head.top().doc == doc) {
       addLead(head.pop());
@@ -437,7 +455,7 @@ final class WANDScorer extends Scorer {
 
   /** Move iterators to the tail until there is a potential match. */
   private int doNextCompetitiveCandidate() throws IOException {
-    while (leadMaxScore + tailMaxScore < minCompetitiveScore) {
+    while (leadMaxScore + tailMaxScore < minCompetitiveScore || freq + tailSize < minShouldMatch) {
       // no match on doc is possible, move to the next potential match
       pushBackLeads(doc + 1);
       moveToNextCandidate(doc + 1);
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestWANDScorer.java b/lucene/core/src/test/org/apache/lucene/search/TestWANDScorer.java
index 543237b..c9381fe 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestWANDScorer.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestWANDScorer.java
@@ -229,6 +229,231 @@ public class TestWANDScorer extends LuceneTestCase {
     dir.close();
   }
 
+  public void testBasicsWithDisjunctionAndMinShouldMatch() throws Exception {
+    try (Directory dir = newDirectory()) {
+      try (IndexWriter w =
+          new IndexWriter(dir, newIndexWriterConfig().setMergePolicy(newLogMergePolicy()))) {
+        for (String[] values :
+            Arrays.asList(
+                new String[] {"A", "B"}, // 0
+                new String[] {"A"}, // 1
+                new String[] {}, // 2
+                new String[] {"A", "B", "C"}, // 3
+                new String[] {"B"}, // 4
+                new String[] {"B", "C"} // 5
+                )) {
+          Document doc = new Document();
+          for (String value : values) {
+            doc.add(new StringField("foo", value, Store.NO));
+          }
+          w.addDocument(doc);
+        }
+
+        w.forceMerge(1);
+      }
+
+      try (IndexReader reader = DirectoryReader.open(dir)) {
+        IndexSearcher searcher = newSearcher(reader);
+
+        Query query =
+            new BooleanQuery.Builder()
+                .add(
+                    new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "A"))), 2),
+                    Occur.SHOULD)
+                .add(new ConstantScoreQuery(new TermQuery(new Term("foo", "B"))), Occur.SHOULD)
+                .add(
+                    new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "C"))), 3),
+                    Occur.SHOULD)
+                .setMinimumNumberShouldMatch(2)
+                .build();
+
+        Scorer scorer =
+            searcher
+                .createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1)
+                .scorer(searcher.getIndexReader().leaves().get(0));
+
+        assertEquals(0, scorer.iterator().nextDoc());
+        assertEquals(2 + 1, scorer.score(), 0);
+
+        assertEquals(3, scorer.iterator().nextDoc());
+        assertEquals(2 + 1 + 3, scorer.score(), 0);
+
+        assertEquals(5, scorer.iterator().nextDoc());
+        assertEquals(1 + 3, scorer.score(), 0);
+
+        assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc());
+
+        scorer =
+            searcher
+                .createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1)
+                .scorer(searcher.getIndexReader().leaves().get(0));
+        scorer.setMinCompetitiveScore(4);
+
+        assertEquals(3, scorer.iterator().nextDoc());
+        assertEquals(2 + 1 + 3, scorer.score(), 0);
+
+        assertEquals(5, scorer.iterator().nextDoc());
+        assertEquals(1 + 3, scorer.score(), 0);
+
+        assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc());
+
+        scorer =
+            searcher
+                .createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1)
+                .scorer(searcher.getIndexReader().leaves().get(0));
+
+        assertEquals(0, scorer.iterator().nextDoc());
+        assertEquals(2 + 1, scorer.score(), 0);
+
+        scorer.setMinCompetitiveScore(10);
+
+        assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc());
+      }
+    }
+  }
+
+  public void testBasicsWithFilteredDisjunctionAndMinShouldMatch() throws Exception {
+    try (Directory dir = newDirectory()) {
+      try (IndexWriter w =
+          new IndexWriter(dir, newIndexWriterConfig().setMergePolicy(newLogMergePolicy()))) {
+        for (String[] values :
+            Arrays.asList(
+                new String[] {"A", "B"}, // 0
+                new String[] {"A", "C", "D"}, // 1
+                new String[] {}, // 2
+                new String[] {"A", "B", "C", "D"}, // 3
+                new String[] {"B"}, // 4
+                new String[] {"C", "D"} // 5
+                )) {
+          Document doc = new Document();
+          for (String value : values) {
+            doc.add(new StringField("foo", value, Store.NO));
+          }
+          w.addDocument(doc);
+        }
+
+        w.forceMerge(1);
+      }
+
+      try (IndexReader reader = DirectoryReader.open(dir)) {
+        IndexSearcher searcher = newSearcher(reader);
+
+        Query query =
+            new BooleanQuery.Builder()
+                .add(
+                    new BooleanQuery.Builder()
+                        .add(
+                            new BoostQuery(
+                                new ConstantScoreQuery(new TermQuery(new Term("foo", "A"))), 2),
+                            Occur.SHOULD)
+                        .add(
+                            new ConstantScoreQuery(new TermQuery(new Term("foo", "B"))),
+                            Occur.SHOULD)
+                        .add(
+                            new BoostQuery(
+                                new ConstantScoreQuery(new TermQuery(new Term("foo", "D"))), 4),
+                            Occur.SHOULD)
+                        .setMinimumNumberShouldMatch(2)
+                        .build(),
+                    Occur.MUST)
+                .add(new TermQuery(new Term("foo", "C")), Occur.FILTER)
+                .build();
+
+        Scorer scorer =
+            searcher
+                .createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1)
+                .scorer(searcher.getIndexReader().leaves().get(0));
+
+        assertEquals(1, scorer.iterator().nextDoc());
+        assertEquals(2 + 4, scorer.score(), 0);
+
+        assertEquals(3, scorer.iterator().nextDoc());
+        assertEquals(2 + 1 + 4, scorer.score(), 0);
+
+        assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc());
+
+        scorer =
+            searcher
+                .createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1)
+                .scorer(searcher.getIndexReader().leaves().get(0));
+
+        scorer.setMinCompetitiveScore(2 + 1 + 4);
+
+        assertEquals(3, scorer.iterator().nextDoc());
+        assertEquals(2 + 1 + 4, scorer.score(), 0);
+
+        assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc());
+      }
+    }
+  }
+
+  public void testBasicsWithFilteredDisjunctionAndMustNotAndMinShouldMatch() throws Exception {
+    try (Directory dir = newDirectory()) {
+      try (IndexWriter w =
+          new IndexWriter(dir, newIndexWriterConfig().setMergePolicy(newLogMergePolicy()))) {
+        for (String[] values :
+            Arrays.asList(
+                new String[] {"A", "B"}, // 0
+                new String[] {"A", "C", "D"}, // 1
+                new String[] {}, // 2
+                new String[] {"A", "B", "C", "D"}, // 3
+                new String[] {"B", "D"}, // 4
+                new String[] {"C", "D"} // 5
+                )) {
+          Document doc = new Document();
+          for (String value : values) {
+            doc.add(new StringField("foo", value, Store.NO));
+          }
+          w.addDocument(doc);
+        }
+
+        w.forceMerge(1);
+      }
+
+      try (IndexReader reader = DirectoryReader.open(dir)) {
+        IndexSearcher searcher = newSearcher(reader);
+
+        Query query =
+            new BooleanQuery.Builder()
+                .add(
+                    new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "A"))), 2),
+                    Occur.SHOULD)
+                .add(new ConstantScoreQuery(new TermQuery(new Term("foo", "B"))), Occur.SHOULD)
+                .add(new TermQuery(new Term("foo", "C")), Occur.MUST_NOT)
+                .add(
+                    new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "D"))), 4),
+                    Occur.SHOULD)
+                .setMinimumNumberShouldMatch(2)
+                .build();
+
+        Scorer scorer =
+            searcher
+                .createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1)
+                .scorer(searcher.getIndexReader().leaves().get(0));
+
+        assertEquals(0, scorer.iterator().nextDoc());
+        assertEquals(2 + 1, scorer.score(), 0);
+
+        assertEquals(4, scorer.iterator().nextDoc());
+        assertEquals(1 + 4, scorer.score(), 0);
+
+        assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc());
+
+        scorer =
+            searcher
+                .createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1)
+                .scorer(searcher.getIndexReader().leaves().get(0));
+
+        scorer.setMinCompetitiveScore(4);
+
+        assertEquals(4, scorer.iterator().nextDoc());
+        assertEquals(1 + 4, scorer.score(), 0);
+
+        assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc());
+      }
+    }
+  }
+
   public void testRandom() throws IOException {
     Directory dir = newDirectory();
     IndexWriter w = new IndexWriter(dir, newIndexWriterConfig());