You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by dw...@apache.org on 2021/03/10 09:58:48 UTC
[lucene] 33/49: LUCENE-9346: Support minimumNumberShouldMatch in WANDScorer (#2141)
This is an automated email from the ASF dual-hosted git repository.
dweiss pushed a commit to branch jira/solr-13105-toMerge
in repository https://gitbox.apache.org/repos/asf/lucene.git
commit 7c03cae553a0bc83d90920804c6714204eaa8391
Author: zacharymorn <za...@yahoo.com>
AuthorDate: Mon Jan 11 06:03:29 2021 -0800
LUCENE-9346: Support minimumNumberShouldMatch in WANDScorer (#2141)
Co-authored-by: Adrien Grand <jp...@gmail.com>
---
.../lucene/search/Boolean2ScorerSupplier.java | 9 +-
.../lucene/search/MinShouldMatchSumScorer.java | 30 +--
.../java/org/apache/lucene/search/ScorerUtil.java | 49 +++++
.../java/org/apache/lucene/search/WANDScorer.java | 40 +++-
.../org/apache/lucene/search/TestWANDScorer.java | 225 +++++++++++++++++++++
5 files changed, 309 insertions(+), 44 deletions(-)
diff --git a/lucene/core/src/java/org/apache/lucene/search/Boolean2ScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/search/Boolean2ScorerSupplier.java
index 3fa5886..d0a4dbd 100644
--- a/lucene/core/src/java/org/apache/lucene/search/Boolean2ScorerSupplier.java
+++ b/lucene/core/src/java/org/apache/lucene/search/Boolean2ScorerSupplier.java
@@ -75,7 +75,7 @@ final class Boolean2ScorerSupplier extends ScorerSupplier {
} else {
final Collection<ScorerSupplier> optionalScorers = subs.get(Occur.SHOULD);
final long shouldCost =
- MinShouldMatchSumScorer.cost(
+ ScorerUtil.costWithMinShouldMatch(
optionalScorers.stream().mapToLong(ScorerSupplier::cost),
optionalScorers.size(),
minShouldMatch);
@@ -230,10 +230,11 @@ final class Boolean2ScorerSupplier extends ScorerSupplier {
for (ScorerSupplier scorer : optional) {
optionalScorers.add(scorer.get(leadCost));
}
- if (minShouldMatch > 1) {
+
+ if (scoreMode == ScoreMode.TOP_SCORES) {
+ return new WANDScorer(weight, optionalScorers, minShouldMatch);
+ } else if (minShouldMatch > 1) {
return new MinShouldMatchSumScorer(weight, optionalScorers, minShouldMatch);
- } else if (scoreMode == ScoreMode.TOP_SCORES) {
- return new WANDScorer(weight, optionalScorers);
} else {
return new DisjunctionSumScorer(weight, optionalScorers, scoreMode);
}
diff --git a/lucene/core/src/java/org/apache/lucene/search/MinShouldMatchSumScorer.java b/lucene/core/src/java/org/apache/lucene/search/MinShouldMatchSumScorer.java
index bdcdca9..574fd1a 100644
--- a/lucene/core/src/java/org/apache/lucene/search/MinShouldMatchSumScorer.java
+++ b/lucene/core/src/java/org/apache/lucene/search/MinShouldMatchSumScorer.java
@@ -24,9 +24,6 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
-import java.util.stream.LongStream;
-import java.util.stream.StreamSupport;
-import org.apache.lucene.util.PriorityQueue;
/**
* A {@link Scorer} for {@link BooleanQuery} when {@link
@@ -44,31 +41,6 @@ import org.apache.lucene.util.PriorityQueue;
*/
final class MinShouldMatchSumScorer extends Scorer {
- static long cost(LongStream costs, int numScorers, int minShouldMatch) {
- // the idea here is the following: a boolean query c1,c2,...cn with minShouldMatch=m
- // could be rewritten to:
- // (c1 AND (c2..cn|msm=m-1)) OR (!c1 AND (c2..cn|msm=m))
- // if we assume that clauses come in ascending cost, then
- // the cost of the first part is the cost of c1 (because the cost of a conjunction is
- // the cost of the least costly clause)
- // the cost of the second part is the cost of finding m matches among the c2...cn
- // remaining clauses
- // since it is a disjunction overall, the total cost is the sum of the costs of these
- // two parts
-
- // If we recurse infinitely, we find out that the cost of a msm query is the sum of the
- // costs of the num_scorers - minShouldMatch + 1 least costly scorers
- final PriorityQueue<Long> pq =
- new PriorityQueue<Long>(numScorers - minShouldMatch + 1) {
- @Override
- protected boolean lessThan(Long a, Long b) {
- return a > b;
- }
- };
- costs.forEach(pq::insertWithOverflow);
- return StreamSupport.stream(pq.spliterator(), false).mapToLong(Number::longValue).sum();
- }
-
final int minShouldMatch;
// list of scorers which 'lead' the iteration and are currently
@@ -111,7 +83,7 @@ final class MinShouldMatchSumScorer extends Scorer {
}
this.cost =
- cost(
+ ScorerUtil.costWithMinShouldMatch(
scorers.stream().map(Scorer::iterator).mapToLong(DocIdSetIterator::cost),
scorers.size(),
minShouldMatch);
diff --git a/lucene/core/src/java/org/apache/lucene/search/ScorerUtil.java b/lucene/core/src/java/org/apache/lucene/search/ScorerUtil.java
new file mode 100644
index 0000000..50c9607
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/search/ScorerUtil.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import java.util.stream.LongStream;
+import java.util.stream.StreamSupport;
+import org.apache.lucene.util.PriorityQueue;
+
+/** Util class for Scorer related methods */
+class ScorerUtil {
+ static long costWithMinShouldMatch(LongStream costs, int numScorers, int minShouldMatch) {
+ // the idea here is the following: a boolean query c1,c2,...cn with minShouldMatch=m
+ // could be rewritten to:
+ // (c1 AND (c2..cn|msm=m-1)) OR (!c1 AND (c2..cn|msm=m))
+ // if we assume that clauses come in ascending cost, then
+ // the cost of the first part is the cost of c1 (because the cost of a conjunction is
+ // the cost of the least costly clause)
+ // the cost of the second part is the cost of finding m matches among the c2...cn
+ // remaining clauses
+ // since it is a disjunction overall, the total cost is the sum of the costs of these
+ // two parts
+
+ // If we recurse infinitely, we find out that the cost of a msm query is the sum of the
+ // costs of the num_scorers - minShouldMatch + 1 least costly scorers
+ final PriorityQueue<Long> pq =
+ new PriorityQueue<Long>(numScorers - minShouldMatch + 1) {
+ @Override
+ protected boolean lessThan(Long a, Long b) {
+ return a > b;
+ }
+ };
+ costs.forEach(pq::insertWithOverflow);
+ return StreamSupport.stream(pq.spliterator(), false).mapToLong(Number::longValue).sum();
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/search/WANDScorer.java b/lucene/core/src/java/org/apache/lucene/search/WANDScorer.java
index 2c94159..b1ed3bf 100644
--- a/lucene/core/src/java/org/apache/lucene/search/WANDScorer.java
+++ b/lucene/core/src/java/org/apache/lucene/search/WANDScorer.java
@@ -19,6 +19,7 @@ package org.apache.lucene.search;
import static org.apache.lucene.search.DisiPriorityQueue.leftNode;
import static org.apache.lucene.search.DisiPriorityQueue.parentNode;
import static org.apache.lucene.search.DisiPriorityQueue.rightNode;
+import static org.apache.lucene.search.ScorerUtil.costWithMinShouldMatch;
import java.io.IOException;
import java.util.ArrayList;
@@ -130,10 +131,21 @@ final class WANDScorer extends Scorer {
int upTo; // upper bound for which max scores are valid
- WANDScorer(Weight weight, Collection<Scorer> scorers) throws IOException {
+ final int minShouldMatch;
+ int freq;
+
+ WANDScorer(Weight weight, Collection<Scorer> scorers, int minShouldMatch) throws IOException {
super(weight);
+ if (minShouldMatch >= scorers.size()) {
+ throw new IllegalArgumentException("minShouldMatch should be < the number of scorers");
+ }
+
this.minCompetitiveScore = 0;
+
+ assert minShouldMatch >= 0 : "minShouldMatch should not be negative, but got " + minShouldMatch;
+ this.minShouldMatch = minShouldMatch;
+
this.doc = -1;
this.upTo = -1; // will be computed on the first call to nextDoc/advance
@@ -155,13 +167,15 @@ final class WANDScorer extends Scorer {
// Use a scaling factor of 0 if all max scores are either 0 or +Infty
this.scalingFactor = scalingFactor.orElse(0);
- long cost = 0;
for (Scorer scorer : scorers) {
- DisiWrapper w = new DisiWrapper(scorer);
- cost += w.cost;
- addLead(w);
+ addLead(new DisiWrapper(scorer));
}
- this.cost = cost;
+
+ this.cost =
+ costWithMinShouldMatch(
+ scorers.stream().map(Scorer::iterator).mapToLong(DocIdSetIterator::cost),
+ scorers.size(),
+ minShouldMatch);
this.maxScorePropagator = new MaxScoreSumPropagator(scorers);
}
@@ -265,15 +279,17 @@ final class WANDScorer extends Scorer {
@Override
public boolean matches() throws IOException {
- while (leadMaxScore < minCompetitiveScore) {
- if (leadMaxScore + tailMaxScore >= minCompetitiveScore) {
+ while (leadMaxScore < minCompetitiveScore || freq < minShouldMatch) {
+ if (leadMaxScore + tailMaxScore < minCompetitiveScore
+ || freq + tailSize < minShouldMatch) {
+ return false;
+ } else {
// a match on doc is still possible, try to
// advance scorers from the tail
advanceTail();
- } else {
- return false;
}
}
+
return true;
}
@@ -290,6 +306,7 @@ final class WANDScorer extends Scorer {
lead.next = this.lead;
this.lead = lead;
leadMaxScore += lead.maxScore;
+ freq += 1;
}
/** Move disis that are in 'lead' back to the tail. */
@@ -429,6 +446,7 @@ final class WANDScorer extends Scorer {
lead = head.pop();
lead.next = null;
leadMaxScore = lead.maxScore;
+ freq = 1;
doc = lead.doc;
while (head.size() > 0 && head.top().doc == doc) {
addLead(head.pop());
@@ -437,7 +455,7 @@ final class WANDScorer extends Scorer {
/** Move iterators to the tail until there is a potential match. */
private int doNextCompetitiveCandidate() throws IOException {
- while (leadMaxScore + tailMaxScore < minCompetitiveScore) {
+ while (leadMaxScore + tailMaxScore < minCompetitiveScore || freq + tailSize < minShouldMatch) {
// no match on doc is possible, move to the next potential match
pushBackLeads(doc + 1);
moveToNextCandidate(doc + 1);
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestWANDScorer.java b/lucene/core/src/test/org/apache/lucene/search/TestWANDScorer.java
index 543237b..c9381fe 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestWANDScorer.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestWANDScorer.java
@@ -229,6 +229,231 @@ public class TestWANDScorer extends LuceneTestCase {
dir.close();
}
+ public void testBasicsWithDisjunctionAndMinShouldMatch() throws Exception {
+ try (Directory dir = newDirectory()) {
+ try (IndexWriter w =
+ new IndexWriter(dir, newIndexWriterConfig().setMergePolicy(newLogMergePolicy()))) {
+ for (String[] values :
+ Arrays.asList(
+ new String[] {"A", "B"}, // 0
+ new String[] {"A"}, // 1
+ new String[] {}, // 2
+ new String[] {"A", "B", "C"}, // 3
+ new String[] {"B"}, // 4
+ new String[] {"B", "C"} // 5
+ )) {
+ Document doc = new Document();
+ for (String value : values) {
+ doc.add(new StringField("foo", value, Store.NO));
+ }
+ w.addDocument(doc);
+ }
+
+ w.forceMerge(1);
+ }
+
+ try (IndexReader reader = DirectoryReader.open(dir)) {
+ IndexSearcher searcher = newSearcher(reader);
+
+ Query query =
+ new BooleanQuery.Builder()
+ .add(
+ new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "A"))), 2),
+ Occur.SHOULD)
+ .add(new ConstantScoreQuery(new TermQuery(new Term("foo", "B"))), Occur.SHOULD)
+ .add(
+ new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "C"))), 3),
+ Occur.SHOULD)
+ .setMinimumNumberShouldMatch(2)
+ .build();
+
+ Scorer scorer =
+ searcher
+ .createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1)
+ .scorer(searcher.getIndexReader().leaves().get(0));
+
+ assertEquals(0, scorer.iterator().nextDoc());
+ assertEquals(2 + 1, scorer.score(), 0);
+
+ assertEquals(3, scorer.iterator().nextDoc());
+ assertEquals(2 + 1 + 3, scorer.score(), 0);
+
+ assertEquals(5, scorer.iterator().nextDoc());
+ assertEquals(1 + 3, scorer.score(), 0);
+
+ assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc());
+
+ scorer =
+ searcher
+ .createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1)
+ .scorer(searcher.getIndexReader().leaves().get(0));
+ scorer.setMinCompetitiveScore(4);
+
+ assertEquals(3, scorer.iterator().nextDoc());
+ assertEquals(2 + 1 + 3, scorer.score(), 0);
+
+ assertEquals(5, scorer.iterator().nextDoc());
+ assertEquals(1 + 3, scorer.score(), 0);
+
+ assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc());
+
+ scorer =
+ searcher
+ .createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1)
+ .scorer(searcher.getIndexReader().leaves().get(0));
+
+ assertEquals(0, scorer.iterator().nextDoc());
+ assertEquals(2 + 1, scorer.score(), 0);
+
+ scorer.setMinCompetitiveScore(10);
+
+ assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc());
+ }
+ }
+ }
+
+ public void testBasicsWithFilteredDisjunctionAndMinShouldMatch() throws Exception {
+ try (Directory dir = newDirectory()) {
+ try (IndexWriter w =
+ new IndexWriter(dir, newIndexWriterConfig().setMergePolicy(newLogMergePolicy()))) {
+ for (String[] values :
+ Arrays.asList(
+ new String[] {"A", "B"}, // 0
+ new String[] {"A", "C", "D"}, // 1
+ new String[] {}, // 2
+ new String[] {"A", "B", "C", "D"}, // 3
+ new String[] {"B"}, // 4
+ new String[] {"C", "D"} // 5
+ )) {
+ Document doc = new Document();
+ for (String value : values) {
+ doc.add(new StringField("foo", value, Store.NO));
+ }
+ w.addDocument(doc);
+ }
+
+ w.forceMerge(1);
+ }
+
+ try (IndexReader reader = DirectoryReader.open(dir)) {
+ IndexSearcher searcher = newSearcher(reader);
+
+ Query query =
+ new BooleanQuery.Builder()
+ .add(
+ new BooleanQuery.Builder()
+ .add(
+ new BoostQuery(
+ new ConstantScoreQuery(new TermQuery(new Term("foo", "A"))), 2),
+ Occur.SHOULD)
+ .add(
+ new ConstantScoreQuery(new TermQuery(new Term("foo", "B"))),
+ Occur.SHOULD)
+ .add(
+ new BoostQuery(
+ new ConstantScoreQuery(new TermQuery(new Term("foo", "D"))), 4),
+ Occur.SHOULD)
+ .setMinimumNumberShouldMatch(2)
+ .build(),
+ Occur.MUST)
+ .add(new TermQuery(new Term("foo", "C")), Occur.FILTER)
+ .build();
+
+ Scorer scorer =
+ searcher
+ .createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1)
+ .scorer(searcher.getIndexReader().leaves().get(0));
+
+ assertEquals(1, scorer.iterator().nextDoc());
+ assertEquals(2 + 4, scorer.score(), 0);
+
+ assertEquals(3, scorer.iterator().nextDoc());
+ assertEquals(2 + 1 + 4, scorer.score(), 0);
+
+ assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc());
+
+ scorer =
+ searcher
+ .createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1)
+ .scorer(searcher.getIndexReader().leaves().get(0));
+
+ scorer.setMinCompetitiveScore(2 + 1 + 4);
+
+ assertEquals(3, scorer.iterator().nextDoc());
+ assertEquals(2 + 1 + 4, scorer.score(), 0);
+
+ assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc());
+ }
+ }
+ }
+
+ public void testBasicsWithFilteredDisjunctionAndMustNotAndMinShouldMatch() throws Exception {
+ try (Directory dir = newDirectory()) {
+ try (IndexWriter w =
+ new IndexWriter(dir, newIndexWriterConfig().setMergePolicy(newLogMergePolicy()))) {
+ for (String[] values :
+ Arrays.asList(
+ new String[] {"A", "B"}, // 0
+ new String[] {"A", "C", "D"}, // 1
+ new String[] {}, // 2
+ new String[] {"A", "B", "C", "D"}, // 3
+ new String[] {"B", "D"}, // 4
+ new String[] {"C", "D"} // 5
+ )) {
+ Document doc = new Document();
+ for (String value : values) {
+ doc.add(new StringField("foo", value, Store.NO));
+ }
+ w.addDocument(doc);
+ }
+
+ w.forceMerge(1);
+ }
+
+ try (IndexReader reader = DirectoryReader.open(dir)) {
+ IndexSearcher searcher = newSearcher(reader);
+
+ Query query =
+ new BooleanQuery.Builder()
+ .add(
+ new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "A"))), 2),
+ Occur.SHOULD)
+ .add(new ConstantScoreQuery(new TermQuery(new Term("foo", "B"))), Occur.SHOULD)
+ .add(new TermQuery(new Term("foo", "C")), Occur.MUST_NOT)
+ .add(
+ new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "D"))), 4),
+ Occur.SHOULD)
+ .setMinimumNumberShouldMatch(2)
+ .build();
+
+ Scorer scorer =
+ searcher
+ .createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1)
+ .scorer(searcher.getIndexReader().leaves().get(0));
+
+ assertEquals(0, scorer.iterator().nextDoc());
+ assertEquals(2 + 1, scorer.score(), 0);
+
+ assertEquals(4, scorer.iterator().nextDoc());
+ assertEquals(1 + 4, scorer.score(), 0);
+
+ assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc());
+
+ scorer =
+ searcher
+ .createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1)
+ .scorer(searcher.getIndexReader().leaves().get(0));
+
+ scorer.setMinCompetitiveScore(4);
+
+ assertEquals(4, scorer.iterator().nextDoc());
+ assertEquals(1 + 4, scorer.score(), 0);
+
+ assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc());
+ }
+ }
+ }
+
public void testRandom() throws IOException {
Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig());