You are viewing a plain text version of this content. The link to the canonical version was not preserved in this plain text copy.
Posted to commits@lucene.apache.org by jp...@apache.org on 2022/06/23 15:56:21 UTC
[lucene] branch main updated: LUCENE-10620: Pass the Weight to Collectors. (#964)
This is an automated email from the ASF dual-hosted git repository.
jpountz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/lucene.git
The following commit(s) were added to refs/heads/main by this push:
new 4c1ae2a332c LUCENE-10620: Pass the Weight to Collectors. (#964)
4c1ae2a332c is described below
commit 4c1ae2a332cb85878310142ebb9fd5beba0345f2
Author: Adrien Grand <jp...@gmail.com>
AuthorDate: Thu Jun 23 17:56:15 2022 +0200
LUCENE-10620: Pass the Weight to Collectors. (#964)
This allows `Collector`s to use `Weight#count` when appropriate.
---
.../java/org/apache/lucene/search/Collector.java | 7 ++
.../org/apache/lucene/search/FilterCollector.java | 5 ++
.../org/apache/lucene/search/IndexSearcher.java | 52 +-------------
.../org/apache/lucene/search/MultiCollector.java | 7 ++
.../lucene/search/TotalHitCountCollector.java | 40 ++++++++---
.../org/apache/lucene/search/TestBooleanQuery.java | 3 +-
.../apache/lucene/search/TestLRUQueryCache.java | 19 ++---
.../apache/lucene/search/TestMultiCollector.java | 16 +++--
.../lucene/search/TestSearchWithThreads.java | 3 +-
.../org/apache/lucene/search/TestTermQuery.java | 6 +-
.../lucene/search/TestTotalHitCountCollector.java | 11 +++
.../lucene/sandbox/search/ProfilerCollector.java | 6 ++
...tIndexSortSortedNumericDocValuesRangeQuery.java | 5 +-
.../sandbox/search/TestMultiRangeQueries.java | 6 +-
.../lucene/tests/search/AssertingCollector.java | 11 +++
.../tests/search/DummyTotalHitCountCollector.java | 83 ++++++++++++++++++++++
16 files changed, 196 insertions(+), 84 deletions(-)
diff --git a/lucene/core/src/java/org/apache/lucene/search/Collector.java b/lucene/core/src/java/org/apache/lucene/search/Collector.java
index 3af2210bd6e..7c02e446755 100644
--- a/lucene/core/src/java/org/apache/lucene/search/Collector.java
+++ b/lucene/core/src/java/org/apache/lucene/search/Collector.java
@@ -56,4 +56,11 @@ public interface Collector {
/** Indicates what features are required from the scorer. */
ScoreMode scoreMode();
+
+ /**
+ * Set the {@link Weight} that will be used to produce scorers that will feed {@link
+ * LeafCollector}s. This is typically useful to have access to {@link Weight#count} from {@link
+ * Collector#getLeafCollector}.
+ */
+ default void setWeight(Weight weight) {}
}
diff --git a/lucene/core/src/java/org/apache/lucene/search/FilterCollector.java b/lucene/core/src/java/org/apache/lucene/search/FilterCollector.java
index 9f57a57b00a..32c52abe1b7 100644
--- a/lucene/core/src/java/org/apache/lucene/search/FilterCollector.java
+++ b/lucene/core/src/java/org/apache/lucene/search/FilterCollector.java
@@ -38,6 +38,11 @@ public abstract class FilterCollector implements Collector {
return in.getLeafCollector(context);
}
+ @Override
+ public void setWeight(Weight weight) {
+ in.setWeight(weight);
+ }
+
@Override
public String toString() {
return getClass().getSimpleName() + "(" + in + ")";
diff --git a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java
index 4e64089f3c1..0f6bdfdeb10 100644
--- a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java
+++ b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java
@@ -412,61 +412,13 @@ public class IndexSearcher {
return similarity;
}
- private static class ShortcutHitCountCollector implements Collector {
- private final Weight weight;
- private final TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
- private int weightCount;
-
- ShortcutHitCountCollector(Weight weight) {
- this.weight = weight;
- }
-
- @Override
- public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
- int count = weight.count(context);
- // check if the number of hits can be computed in constant time
- if (count == -1) {
- // use a TotalHitCountCollector to calculate the number of hits in the usual way
- return totalHitCountCollector.getLeafCollector(context);
- } else {
- weightCount += count;
- throw new CollectionTerminatedException();
- }
- }
-
- @Override
- public ScoreMode scoreMode() {
- return ScoreMode.COMPLETE_NO_SCORES;
- }
- }
-
/**
* Count how many documents match the given query. May be faster than counting number of hits by
* collecting all matches, as the number of hits is retrieved from the index statistics when
* possible.
*/
public int count(Query query) throws IOException {
- query = rewrite(query, false);
- final Weight weight = createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1);
-
- final CollectorManager<ShortcutHitCountCollector, Integer> shortcutCollectorManager =
- new CollectorManager<ShortcutHitCountCollector, Integer>() {
- @Override
- public ShortcutHitCountCollector newCollector() throws IOException {
- return new ShortcutHitCountCollector(weight);
- }
-
- @Override
- public Integer reduce(Collection<ShortcutHitCountCollector> collectors)
- throws IOException {
- int totalHitCount = 0;
- for (ShortcutHitCountCollector c : collectors) {
- totalHitCount += c.weightCount + c.totalHitCountCollector.getTotalHits();
- }
- return totalHitCount;
- }
- };
- return search(weight, shortcutCollectorManager, new ShortcutHitCountCollector(weight));
+ return search(new ConstantScoreQuery(query), new TotalHitCountCollectorManager());
}
/**
@@ -750,6 +702,8 @@ public class IndexSearcher {
protected void search(List<LeafReaderContext> leaves, Weight weight, Collector collector)
throws IOException {
+ collector.setWeight(weight);
+
// TODO: should we make this
// threaded...? the Collector could be sync'd?
// always use single thread:
diff --git a/lucene/core/src/java/org/apache/lucene/search/MultiCollector.java b/lucene/core/src/java/org/apache/lucene/search/MultiCollector.java
index 09aea3a02a3..5452c0f8d69 100644
--- a/lucene/core/src/java/org/apache/lucene/search/MultiCollector.java
+++ b/lucene/core/src/java/org/apache/lucene/search/MultiCollector.java
@@ -149,6 +149,13 @@ public class MultiCollector implements Collector {
}
}
+ @Override
+ public void setWeight(Weight weight) {
+ for (Collector collector : collectors) {
+ collector.setWeight(weight);
+ }
+ }
+
/** Provides access to the wrapped {@code Collector}s for advanced use-cases */
public Collector[] getCollectors() {
return collectors;
diff --git a/lucene/core/src/java/org/apache/lucene/search/TotalHitCountCollector.java b/lucene/core/src/java/org/apache/lucene/search/TotalHitCountCollector.java
index 9d9ad4149b0..30d0659f2cd 100644
--- a/lucene/core/src/java/org/apache/lucene/search/TotalHitCountCollector.java
+++ b/lucene/core/src/java/org/apache/lucene/search/TotalHitCountCollector.java
@@ -16,13 +16,16 @@
*/
package org.apache.lucene.search;
+import java.io.IOException;
+import org.apache.lucene.index.LeafReaderContext;
+
/**
- * Just counts the total number of hits. For cases when this is the only collector used, {@link
- * IndexSearcher#count(Query)} should be called instead of {@link IndexSearcher#search(Query,
- * Collector)} as the former is faster whenever the count can be returned directly from the index
- * statistics.
+ * Just counts the total number of hits. This is the collector behind {@link IndexSearcher#count}.
+ * When the {@link Weight} implements {@link Weight#count}, this collector will skip collecting
+ * segments.
*/
-public class TotalHitCountCollector extends SimpleCollector {
+public class TotalHitCountCollector implements Collector {
+ private Weight weight;
private int totalHits;
/** Returns how many hits matched the search. */
@@ -31,12 +34,31 @@ public class TotalHitCountCollector extends SimpleCollector {
}
@Override
- public void collect(int doc) {
- totalHits++;
+ public ScoreMode scoreMode() {
+ return ScoreMode.COMPLETE_NO_SCORES;
}
@Override
- public ScoreMode scoreMode() {
- return ScoreMode.COMPLETE_NO_SCORES;
+ public void setWeight(Weight weight) {
+ this.weight = weight;
+ }
+
+ @Override
+ public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
+ int leafCount = weight == null ? -1 : weight.count(context);
+ if (leafCount != -1) {
+ totalHits += leafCount;
+ throw new CollectionTerminatedException();
+ }
+ return new LeafCollector() {
+
+ @Override
+ public void setScorer(Scorable scorer) throws IOException {}
+
+ @Override
+ public void collect(int doc) throws IOException {
+ totalHits++;
+ }
+ };
}
}
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestBooleanQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestBooleanQuery.java
index 6a391a38274..a5710121acf 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestBooleanQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestBooleanQuery.java
@@ -45,6 +45,7 @@ import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
import org.apache.lucene.tests.search.FixedBitSetCollector;
import org.apache.lucene.tests.search.QueryUtils;
import org.apache.lucene.tests.util.LuceneTestCase;
@@ -1021,7 +1022,7 @@ public class TestBooleanQuery extends LuceneTestCase {
builder.setMinimumNumberShouldMatch(TestUtil.nextInt(random(), 0, numShouldClauses));
Query booleanQuery = builder.build();
assertEquals(
- (int) searcher.search(booleanQuery, new TotalHitCountCollectorManager()),
+ (int) searcher.search(booleanQuery, DummyTotalHitCountCollector.createManager()),
searcher.count(booleanQuery));
}
reader.close();
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestLRUQueryCache.java b/lucene/core/src/test/org/apache/lucene/search/TestLRUQueryCache.java
index a30bb757e60..10826517b1d 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestLRUQueryCache.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestLRUQueryCache.java
@@ -64,6 +64,7 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.search.AssertingIndexSearcher;
import org.apache.lucene.tests.search.CheckHits;
+import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.RamUsageTester;
import org.apache.lucene.tests.util.TestUtil;
@@ -168,8 +169,8 @@ public class TestLRUQueryCache extends LuceneTestCase {
RandomPicks.randomFrom(
random(), new String[] {"blue", "red", "yellow", "green"});
final Query q = new TermQuery(new Term("color", value));
- TotalHitCountCollectorManager collectorManager =
- new TotalHitCountCollectorManager();
+ CollectorManager<DummyTotalHitCountCollector, Integer> collectorManager =
+ DummyTotalHitCountCollector.createManager();
// will use the cache
final int totalHits1 = searcher.search(q, collectorManager);
final long totalHits2 =
@@ -177,8 +178,8 @@ public class TestLRUQueryCache extends LuceneTestCase {
q,
new CollectorManager<FilterCollector, Integer>() {
@Override
- public FilterCollector newCollector() {
- return new FilterCollector(new TotalHitCountCollector()) {
+ public FilterCollector newCollector() throws IOException {
+ return new FilterCollector(collectorManager.newCollector()) {
@Override
public ScoreMode scoreMode() {
// will not use the cache because of scores
@@ -194,7 +195,7 @@ public class TestLRUQueryCache extends LuceneTestCase {
collectors.stream()
.map(
filterCollector ->
- (TotalHitCountCollector) filterCollector.in)
+ (DummyTotalHitCountCollector) filterCollector.in)
.collect(Collectors.toList()));
}
});
@@ -963,7 +964,7 @@ public class TestLRUQueryCache extends LuceneTestCase {
searcher.setQueryCache(queryCache);
searcher.setQueryCachingPolicy(policy);
- searcher.search(query.build(), new TotalHitCountCollectorManager());
+ searcher.search(query.build(), DummyTotalHitCountCollector.createManager());
reader.close();
dir.close();
@@ -1187,12 +1188,12 @@ public class TestLRUQueryCache extends LuceneTestCase {
searcher.setQueryCachingPolicy(ALWAYS_CACHE);
BadQuery query = new BadQuery();
- searcher.search(query, new TotalHitCountCollectorManager());
+ searcher.search(query, DummyTotalHitCountCollector.createManager());
query.i[0] += 1; // change the hashCode!
try {
// trigger an eviction
- searcher.search(new MatchAllDocsQuery(), new TotalHitCountCollectorManager());
+ searcher.search(new MatchAllDocsQuery(), DummyTotalHitCountCollector.createManager());
fail();
} catch (
@SuppressWarnings("unused")
@@ -1273,7 +1274,7 @@ public class TestLRUQueryCache extends LuceneTestCase {
query.add(bar, Occur.FILTER);
query.add(foo, Occur.FILTER);
}
- indexSearcher.search(query.build(), new TotalHitCountCollectorManager());
+ indexSearcher.search(query.build(), DummyTotalHitCountCollector.createManager());
assertEquals(1, policy.frequency(query.build()));
assertEquals(1, policy.frequency(foo));
assertEquals(1, policy.frequency(bar));
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestMultiCollector.java b/lucene/core/src/test/org/apache/lucene/search/TestMultiCollector.java
index 9f8a2f5902d..a8fc829bcc1 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestMultiCollector.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestMultiCollector.java
@@ -32,6 +32,7 @@ import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.junit.Test;
@@ -101,13 +102,13 @@ public class TestMultiCollector extends LuceneTestCase {
final IndexReader reader = w.getReader();
w.close();
final IndexSearcher searcher = newSearcher(reader, true, true, false);
- Map<TotalHitCountCollector, Integer> expectedCounts = new HashMap<>();
+ Map<DummyTotalHitCountCollector, Integer> expectedCounts = new HashMap<>();
List<Collector> collectors = new ArrayList<>();
final int numCollectors = TestUtil.nextInt(random(), 1, 5);
for (int i = 0; i < numCollectors; ++i) {
final int terminateAfter = random().nextInt(numDocs + 10);
final int expectedCount = terminateAfter > numDocs ? numDocs : terminateAfter;
- TotalHitCountCollector collector = new TotalHitCountCollector();
+ DummyTotalHitCountCollector collector = new DummyTotalHitCountCollector();
expectedCounts.put(collector, expectedCount);
collectors.add(new TerminateAfterCollector(collector, terminateAfter));
}
@@ -124,7 +125,8 @@ public class TestMultiCollector extends LuceneTestCase {
return null;
}
});
- for (Map.Entry<TotalHitCountCollector, Integer> expectedCount : expectedCounts.entrySet()) {
+ for (Map.Entry<DummyTotalHitCountCollector, Integer> expectedCount :
+ expectedCounts.entrySet()) {
assertEquals(expectedCount.getValue().intValue(), expectedCount.getKey().getTotalHits());
}
reader.close();
@@ -133,8 +135,8 @@ public class TestMultiCollector extends LuceneTestCase {
}
public void testSetScorerAfterCollectionTerminated() throws IOException {
- Collector collector1 = new TotalHitCountCollector();
- Collector collector2 = new TotalHitCountCollector();
+ Collector collector1 = new DummyTotalHitCountCollector();
+ Collector collector2 = new DummyTotalHitCountCollector();
AtomicBoolean setScorerCalled1 = new AtomicBoolean();
collector1 = new SetScorerCollector(collector1, setScorerCalled1);
@@ -224,7 +226,7 @@ public class TestMultiCollector extends LuceneTestCase {
scorer.setMinCompetitiveScore(minScore);
}
};
- Collector multiCollector = MultiCollector.wrap(collector, new TotalHitCountCollector());
+ Collector multiCollector = MultiCollector.wrap(collector, new DummyTotalHitCountCollector());
LeafCollector leafCollector = multiCollector.getLeafCollector(reader.leaves().get(0));
leafCollector.setScorer(scorer);
leafCollector.collect(0); // no exception
@@ -283,7 +285,7 @@ public class TestMultiCollector extends LuceneTestCase {
List<Collector> cols = new ArrayList<>();
cols.add(collector);
for (int col = 0; col < numCol; col++) {
- cols.add(new TerminateAfterCollector(new TotalHitCountCollector(), 0));
+ cols.add(new TerminateAfterCollector(new DummyTotalHitCountCollector(), 0));
}
Collections.shuffle(cols, random());
Collector multiCollector = MultiCollector.wrap(cols);
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSearchWithThreads.java b/lucene/core/src/test/org/apache/lucene/search/TestSearchWithThreads.java
index 17bcce0ee15..d4bd95fca7a 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestSearchWithThreads.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestSearchWithThreads.java
@@ -24,6 +24,7 @@ import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
import org.apache.lucene.tests.util.LuceneTestCase;
public class TestSearchWithThreads extends LuceneTestCase {
@@ -57,7 +58,7 @@ public class TestSearchWithThreads extends LuceneTestCase {
final AtomicBoolean failed = new AtomicBoolean();
final AtomicLong netSearch = new AtomicLong();
- TotalHitCountCollectorManager collectorManager = new TotalHitCountCollectorManager();
+ CollectorManager<?, Integer> collectorManager = DummyTotalHitCountCollector.createManager();
Thread[] threads = new Thread[numThreads];
for (int threadID = 0; threadID < numThreads; threadID++) {
threads[threadID] =
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java
index 5ad4392355b..9a6f39de726 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java
@@ -34,6 +34,7 @@ import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
import org.apache.lucene.tests.search.QueryUtils;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
@@ -91,14 +92,13 @@ public class TestTermQuery extends LuceneTestCase {
IndexSearcher searcher = new IndexSearcher(reader);
// use a collector rather than searcher.count() which would just read the
// doc freq instead of creating a scorer
- TotalHitCountCollectorManager collectorManager = new TotalHitCountCollectorManager();
- int totalHits = searcher.search(query, collectorManager);
+ int totalHits = searcher.search(query, DummyTotalHitCountCollector.createManager());
assertEquals(1, totalHits);
TermQuery queryWithContext =
new TermQuery(
new Term("foo", "bar"),
TermStates.build(reader.getContext(), new Term("foo", "bar"), true));
- totalHits = searcher.search(queryWithContext, collectorManager);
+ totalHits = searcher.search(queryWithContext, DummyTotalHitCountCollector.createManager());
assertEquals(1, totalHits);
IOUtils.close(reader, w, dir);
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTotalHitCountCollector.java b/lucene/core/src/test/org/apache/lucene/search/TestTotalHitCountCollector.java
index 49049ebd378..eb2afb58e34 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestTotalHitCountCollector.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestTotalHitCountCollector.java
@@ -20,6 +20,8 @@ import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
@@ -42,6 +44,15 @@ public class TestTotalHitCountCollector extends LuceneTestCase {
TotalHitCountCollectorManager collectorManager = new TotalHitCountCollectorManager();
int totalHits = searcher.search(new MatchAllDocsQuery(), collectorManager);
assertEquals(5, totalHits);
+
+ Query query =
+ new BooleanQuery.Builder()
+ .add(new TermQuery(new Term("string", "a1")), Occur.SHOULD)
+ .add(new TermQuery(new Term("string", "b3")), Occur.SHOULD)
+ .build();
+ totalHits = searcher.search(query, collectorManager);
+ assertEquals(2, totalHits);
+
reader.close();
indexStore.close();
}
diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/ProfilerCollector.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/ProfilerCollector.java
index b38f7d626f1..94f56b93189 100644
--- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/ProfilerCollector.java
+++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/ProfilerCollector.java
@@ -24,6 +24,7 @@ import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Weight;
/**
* This class wraps a Collector and times the execution of: - setScorer() - collect() -
@@ -83,6 +84,11 @@ public class ProfilerCollector implements Collector {
return collector.getLeafCollector(context);
}
+ @Override
+ public void setWeight(Weight weight) {
+ collector.setWeight(weight);
+ }
+
@Override
public ScoreMode scoreMode() {
return collector.scoreMode();
diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestIndexSortSortedNumericDocValuesRangeQuery.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestIndexSortSortedNumericDocValuesRangeQuery.java
index c10110af812..4323a005c17 100644
--- a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestIndexSortSortedNumericDocValuesRangeQuery.java
+++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestIndexSortSortedNumericDocValuesRangeQuery.java
@@ -40,11 +40,11 @@ import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.SortedNumericSortField;
import org.apache.lucene.search.TopDocs;
-import org.apache.lucene.search.TotalHitCountCollectorManager;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
import org.apache.lucene.tests.search.QueryUtils;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
@@ -221,7 +221,8 @@ public class TestIndexSortSortedNumericDocValuesRangeQuery extends LuceneTestCas
private static void assertNumberOfHits(IndexSearcher searcher, Query query, int numberOfHits)
throws IOException {
assertEquals(
- numberOfHits, searcher.search(query, new TotalHitCountCollectorManager()).intValue());
+ numberOfHits,
+ searcher.search(query, DummyTotalHitCountCollector.createManager()).intValue());
assertEquals(numberOfHits, searcher.count(query));
}
diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestMultiRangeQueries.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestMultiRangeQueries.java
index fa420b608a0..03b92418199 100644
--- a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestMultiRangeQueries.java
+++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestMultiRangeQueries.java
@@ -37,9 +37,9 @@ import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
-import org.apache.lucene.search.TotalHitCountCollectorManager;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
import org.apache.lucene.tests.search.QueryUtils;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
@@ -808,8 +808,8 @@ public class TestMultiRangeQueries extends LuceneTestCase {
MultiRangeQuery multiRangeQuery = (MultiRangeQuery) builder1.build().rewrite(reader);
BooleanQuery booleanQuery = builder2.build();
- int count = searcher.search(multiRangeQuery, new TotalHitCountCollectorManager());
- int booleanCount = searcher.search(booleanQuery, new TotalHitCountCollectorManager());
+ int count = searcher.search(multiRangeQuery, DummyTotalHitCountCollector.createManager());
+ int booleanCount = searcher.search(booleanQuery, DummyTotalHitCountCollector.createManager());
assertEquals(booleanCount, count);
}
IOUtils.close(reader, w, dir);
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingCollector.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingCollector.java
index 7ffa9350e2b..cf2c2732614 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingCollector.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingCollector.java
@@ -22,10 +22,12 @@ import org.apache.lucene.search.Collector;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FilterCollector;
import org.apache.lucene.search.LeafCollector;
+import org.apache.lucene.search.Weight;
/** A collector that asserts that it is used correctly. */
class AssertingCollector extends FilterCollector {
+ private boolean weightSet = false;
private int maxDoc = -1;
private int previousLeafMaxDoc = 0;
@@ -43,6 +45,7 @@ class AssertingCollector extends FilterCollector {
@Override
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
+ assert weightSet : "Set the weight first";
assert context.docBase >= previousLeafMaxDoc;
previousLeafMaxDoc = context.docBase + context.reader().maxDoc();
@@ -65,4 +68,12 @@ class AssertingCollector extends FilterCollector {
}
};
}
+
+ @Override
+ public void setWeight(Weight weight) {
+ assert weightSet == false : "Weight set twice";
+ weightSet = true;
+ assert weight != null;
+ in.setWeight(weight);
+ }
}
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/DummyTotalHitCountCollector.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/DummyTotalHitCountCollector.java
new file mode 100644
index 00000000000..fcb53b96f0a
--- /dev/null
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/DummyTotalHitCountCollector.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.tests.search;
+
+import java.io.IOException;
+import java.util.Collection;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.CollectorManager;
+import org.apache.lucene.search.LeafCollector;
+import org.apache.lucene.search.Scorable;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.TotalHitCountCollector;
+import org.apache.lucene.search.Weight;
+
+/**
+ * A dummy version of {@link TotalHitCountCollector} that doesn't shortcut using {@link
+ * Weight#count}.
+ */
+public class DummyTotalHitCountCollector implements Collector {
+ private int totalHits;
+
+ /** Constructor */
+ public DummyTotalHitCountCollector() {}
+
+ /** Get the number of hits. */
+ public int getTotalHits() {
+ return totalHits;
+ }
+
+ @Override
+ public ScoreMode scoreMode() {
+ return ScoreMode.COMPLETE_NO_SCORES;
+ }
+
+ @Override
+ public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
+ return new LeafCollector() {
+
+ @Override
+ public void setScorer(Scorable scorer) throws IOException {}
+
+ @Override
+ public void collect(int doc) throws IOException {
+ totalHits++;
+ }
+ };
+ }
+
+ /** Create a collector manager. */
+ public static CollectorManager<DummyTotalHitCountCollector, Integer> createManager() {
+ return new CollectorManager<DummyTotalHitCountCollector, Integer>() {
+
+ @Override
+ public DummyTotalHitCountCollector newCollector() throws IOException {
+ return new DummyTotalHitCountCollector();
+ }
+
+ @Override
+ public Integer reduce(Collection<DummyTotalHitCountCollector> collectors) throws IOException {
+ int sum = 0;
+ for (DummyTotalHitCountCollector coll : collectors) {
+ sum += coll.totalHits;
+ }
+ return sum;
+ }
+ };
+ }
+}