You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by jp...@apache.org on 2022/06/23 15:56:21 UTC

[lucene] branch main updated: LUCENE-10620: Pass the Weight to Collectors. (#964)

This is an automated email from the ASF dual-hosted git repository.

jpountz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/main by this push:
     new 4c1ae2a332c LUCENE-10620: Pass the Weight to Collectors. (#964)
4c1ae2a332c is described below

commit 4c1ae2a332cb85878310142ebb9fd5beba0345f2
Author: Adrien Grand <jp...@gmail.com>
AuthorDate: Thu Jun 23 17:56:15 2022 +0200

    LUCENE-10620: Pass the Weight to Collectors. (#964)
    
    This allows `Collector`s to use `Weight#count` when appropriate.
---
 .../java/org/apache/lucene/search/Collector.java   |  7 ++
 .../org/apache/lucene/search/FilterCollector.java  |  5 ++
 .../org/apache/lucene/search/IndexSearcher.java    | 52 +-------------
 .../org/apache/lucene/search/MultiCollector.java   |  7 ++
 .../lucene/search/TotalHitCountCollector.java      | 40 ++++++++---
 .../org/apache/lucene/search/TestBooleanQuery.java |  3 +-
 .../apache/lucene/search/TestLRUQueryCache.java    | 19 ++---
 .../apache/lucene/search/TestMultiCollector.java   | 16 +++--
 .../lucene/search/TestSearchWithThreads.java       |  3 +-
 .../org/apache/lucene/search/TestTermQuery.java    |  6 +-
 .../lucene/search/TestTotalHitCountCollector.java  | 11 +++
 .../lucene/sandbox/search/ProfilerCollector.java   |  6 ++
 ...tIndexSortSortedNumericDocValuesRangeQuery.java |  5 +-
 .../sandbox/search/TestMultiRangeQueries.java      |  6 +-
 .../lucene/tests/search/AssertingCollector.java    | 11 +++
 .../tests/search/DummyTotalHitCountCollector.java  | 83 ++++++++++++++++++++++
 16 files changed, 196 insertions(+), 84 deletions(-)

diff --git a/lucene/core/src/java/org/apache/lucene/search/Collector.java b/lucene/core/src/java/org/apache/lucene/search/Collector.java
index 3af2210bd6e..7c02e446755 100644
--- a/lucene/core/src/java/org/apache/lucene/search/Collector.java
+++ b/lucene/core/src/java/org/apache/lucene/search/Collector.java
@@ -56,4 +56,11 @@ public interface Collector {
 
   /** Indicates what features are required from the scorer. */
   ScoreMode scoreMode();
+
+  /**
+   * Set the {@link Weight} that will be used to produce scorers that will feed {@link
+   * LeafCollector}s. This is typically useful to have access to {@link Weight#count} from {@link
+   * Collector#getLeafCollector}.
+   */
+  default void setWeight(Weight weight) {}
 }
diff --git a/lucene/core/src/java/org/apache/lucene/search/FilterCollector.java b/lucene/core/src/java/org/apache/lucene/search/FilterCollector.java
index 9f57a57b00a..32c52abe1b7 100644
--- a/lucene/core/src/java/org/apache/lucene/search/FilterCollector.java
+++ b/lucene/core/src/java/org/apache/lucene/search/FilterCollector.java
@@ -38,6 +38,11 @@ public abstract class FilterCollector implements Collector {
     return in.getLeafCollector(context);
   }
 
+  @Override
+  public void setWeight(Weight weight) {
+    in.setWeight(weight);
+  }
+
   @Override
   public String toString() {
     return getClass().getSimpleName() + "(" + in + ")";
diff --git a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java
index 4e64089f3c1..0f6bdfdeb10 100644
--- a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java
+++ b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java
@@ -412,61 +412,13 @@ public class IndexSearcher {
     return similarity;
   }
 
-  private static class ShortcutHitCountCollector implements Collector {
-    private final Weight weight;
-    private final TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
-    private int weightCount;
-
-    ShortcutHitCountCollector(Weight weight) {
-      this.weight = weight;
-    }
-
-    @Override
-    public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
-      int count = weight.count(context);
-      // check if the number of hits can be computed in constant time
-      if (count == -1) {
-        // use a TotalHitCountCollector to calculate the number of hits in the usual way
-        return totalHitCountCollector.getLeafCollector(context);
-      } else {
-        weightCount += count;
-        throw new CollectionTerminatedException();
-      }
-    }
-
-    @Override
-    public ScoreMode scoreMode() {
-      return ScoreMode.COMPLETE_NO_SCORES;
-    }
-  }
-
   /**
    * Count how many documents match the given query. May be faster than counting number of hits by
    * collecting all matches, as the number of hits is retrieved from the index statistics when
    * possible.
    */
   public int count(Query query) throws IOException {
-    query = rewrite(query, false);
-    final Weight weight = createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1);
-
-    final CollectorManager<ShortcutHitCountCollector, Integer> shortcutCollectorManager =
-        new CollectorManager<ShortcutHitCountCollector, Integer>() {
-          @Override
-          public ShortcutHitCountCollector newCollector() throws IOException {
-            return new ShortcutHitCountCollector(weight);
-          }
-
-          @Override
-          public Integer reduce(Collection<ShortcutHitCountCollector> collectors)
-              throws IOException {
-            int totalHitCount = 0;
-            for (ShortcutHitCountCollector c : collectors) {
-              totalHitCount += c.weightCount + c.totalHitCountCollector.getTotalHits();
-            }
-            return totalHitCount;
-          }
-        };
-    return search(weight, shortcutCollectorManager, new ShortcutHitCountCollector(weight));
+    return search(new ConstantScoreQuery(query), new TotalHitCountCollectorManager());
   }
 
   /**
@@ -750,6 +702,8 @@ public class IndexSearcher {
   protected void search(List<LeafReaderContext> leaves, Weight weight, Collector collector)
       throws IOException {
 
+    collector.setWeight(weight);
+
     // TODO: should we make this
     // threaded...? the Collector could be sync'd?
     // always use single thread:
diff --git a/lucene/core/src/java/org/apache/lucene/search/MultiCollector.java b/lucene/core/src/java/org/apache/lucene/search/MultiCollector.java
index 09aea3a02a3..5452c0f8d69 100644
--- a/lucene/core/src/java/org/apache/lucene/search/MultiCollector.java
+++ b/lucene/core/src/java/org/apache/lucene/search/MultiCollector.java
@@ -149,6 +149,13 @@ public class MultiCollector implements Collector {
     }
   }
 
+  @Override
+  public void setWeight(Weight weight) {
+    for (Collector collector : collectors) {
+      collector.setWeight(weight);
+    }
+  }
+
   /** Provides access to the wrapped {@code Collector}s for advanced use-cases */
   public Collector[] getCollectors() {
     return collectors;
diff --git a/lucene/core/src/java/org/apache/lucene/search/TotalHitCountCollector.java b/lucene/core/src/java/org/apache/lucene/search/TotalHitCountCollector.java
index 9d9ad4149b0..30d0659f2cd 100644
--- a/lucene/core/src/java/org/apache/lucene/search/TotalHitCountCollector.java
+++ b/lucene/core/src/java/org/apache/lucene/search/TotalHitCountCollector.java
@@ -16,13 +16,16 @@
  */
 package org.apache.lucene.search;
 
+import java.io.IOException;
+import org.apache.lucene.index.LeafReaderContext;
+
 /**
- * Just counts the total number of hits. For cases when this is the only collector used, {@link
- * IndexSearcher#count(Query)} should be called instead of {@link IndexSearcher#search(Query,
- * Collector)} as the former is faster whenever the count can be returned directly from the index
- * statistics.
+ * Just counts the total number of hits. This is the collector behind {@link IndexSearcher#count}.
+ * For each segment where {@link Weight#count} returns a value other than -1, this collector adds
+ * that value to the total and skips collecting the segment.
  */
-public class TotalHitCountCollector extends SimpleCollector {
+public class TotalHitCountCollector implements Collector {
+  private Weight weight;
   private int totalHits;
 
   /** Returns how many hits matched the search. */
@@ -31,12 +34,31 @@ public class TotalHitCountCollector extends SimpleCollector {
   }
 
   @Override
-  public void collect(int doc) {
-    totalHits++;
+  public ScoreMode scoreMode() {
+    return ScoreMode.COMPLETE_NO_SCORES;
   }
 
   @Override
-  public ScoreMode scoreMode() {
-    return ScoreMode.COMPLETE_NO_SCORES;
+  public void setWeight(Weight weight) {
+    this.weight = weight;
+  }
+
+  @Override
+  public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
+    int leafCount = weight == null ? -1 : weight.count(context);
+    if (leafCount != -1) {
+      totalHits += leafCount;
+      throw new CollectionTerminatedException();
+    }
+    return new LeafCollector() {
+
+      @Override
+      public void setScorer(Scorable scorer) throws IOException {}
+
+      @Override
+      public void collect(int doc) throws IOException {
+        totalHits++;
+      }
+    };
   }
 }
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestBooleanQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestBooleanQuery.java
index 6a391a38274..a5710121acf 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestBooleanQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestBooleanQuery.java
@@ -45,6 +45,7 @@ import org.apache.lucene.search.similarities.ClassicSimilarity;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.analysis.MockAnalyzer;
 import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
 import org.apache.lucene.tests.search.FixedBitSetCollector;
 import org.apache.lucene.tests.search.QueryUtils;
 import org.apache.lucene.tests.util.LuceneTestCase;
@@ -1021,7 +1022,7 @@ public class TestBooleanQuery extends LuceneTestCase {
       builder.setMinimumNumberShouldMatch(TestUtil.nextInt(random(), 0, numShouldClauses));
       Query booleanQuery = builder.build();
       assertEquals(
-          (int) searcher.search(booleanQuery, new TotalHitCountCollectorManager()),
+          (int) searcher.search(booleanQuery, DummyTotalHitCountCollector.createManager()),
           searcher.count(booleanQuery));
     }
     reader.close();
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestLRUQueryCache.java b/lucene/core/src/test/org/apache/lucene/search/TestLRUQueryCache.java
index a30bb757e60..10826517b1d 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestLRUQueryCache.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestLRUQueryCache.java
@@ -64,6 +64,7 @@ import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
 import org.apache.lucene.tests.search.AssertingIndexSearcher;
 import org.apache.lucene.tests.search.CheckHits;
+import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
 import org.apache.lucene.tests.util.LuceneTestCase;
 import org.apache.lucene.tests.util.RamUsageTester;
 import org.apache.lucene.tests.util.TestUtil;
@@ -168,8 +169,8 @@ public class TestLRUQueryCache extends LuceneTestCase {
                         RandomPicks.randomFrom(
                             random(), new String[] {"blue", "red", "yellow", "green"});
                     final Query q = new TermQuery(new Term("color", value));
-                    TotalHitCountCollectorManager collectorManager =
-                        new TotalHitCountCollectorManager();
+                    CollectorManager<DummyTotalHitCountCollector, Integer> collectorManager =
+                        DummyTotalHitCountCollector.createManager();
                     // will use the cache
                     final int totalHits1 = searcher.search(q, collectorManager);
                     final long totalHits2 =
@@ -177,8 +178,8 @@ public class TestLRUQueryCache extends LuceneTestCase {
                             q,
                             new CollectorManager<FilterCollector, Integer>() {
                               @Override
-                              public FilterCollector newCollector() {
-                                return new FilterCollector(new TotalHitCountCollector()) {
+                              public FilterCollector newCollector() throws IOException {
+                                return new FilterCollector(collectorManager.newCollector()) {
                                   @Override
                                   public ScoreMode scoreMode() {
                                     // will not use the cache because of scores
@@ -194,7 +195,7 @@ public class TestLRUQueryCache extends LuceneTestCase {
                                     collectors.stream()
                                         .map(
                                             filterCollector ->
-                                                (TotalHitCountCollector) filterCollector.in)
+                                                (DummyTotalHitCountCollector) filterCollector.in)
                                         .collect(Collectors.toList()));
                               }
                             });
@@ -963,7 +964,7 @@ public class TestLRUQueryCache extends LuceneTestCase {
 
     searcher.setQueryCache(queryCache);
     searcher.setQueryCachingPolicy(policy);
-    searcher.search(query.build(), new TotalHitCountCollectorManager());
+    searcher.search(query.build(), DummyTotalHitCountCollector.createManager());
 
     reader.close();
     dir.close();
@@ -1187,12 +1188,12 @@ public class TestLRUQueryCache extends LuceneTestCase {
     searcher.setQueryCachingPolicy(ALWAYS_CACHE);
 
     BadQuery query = new BadQuery();
-    searcher.search(query, new TotalHitCountCollectorManager());
+    searcher.search(query, DummyTotalHitCountCollector.createManager());
     query.i[0] += 1; // change the hashCode!
 
     try {
       // trigger an eviction
-      searcher.search(new MatchAllDocsQuery(), new TotalHitCountCollectorManager());
+      searcher.search(new MatchAllDocsQuery(), DummyTotalHitCountCollector.createManager());
       fail();
     } catch (
         @SuppressWarnings("unused")
@@ -1273,7 +1274,7 @@ public class TestLRUQueryCache extends LuceneTestCase {
           query.add(bar, Occur.FILTER);
           query.add(foo, Occur.FILTER);
         }
-        indexSearcher.search(query.build(), new TotalHitCountCollectorManager());
+        indexSearcher.search(query.build(), DummyTotalHitCountCollector.createManager());
         assertEquals(1, policy.frequency(query.build()));
         assertEquals(1, policy.frequency(foo));
         assertEquals(1, policy.frequency(bar));
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestMultiCollector.java b/lucene/core/src/test/org/apache/lucene/search/TestMultiCollector.java
index 9f8a2f5902d..a8fc829bcc1 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestMultiCollector.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestMultiCollector.java
@@ -32,6 +32,7 @@ import org.apache.lucene.index.IndexWriter;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
 import org.apache.lucene.tests.util.LuceneTestCase;
 import org.apache.lucene.tests.util.TestUtil;
 import org.junit.Test;
@@ -101,13 +102,13 @@ public class TestMultiCollector extends LuceneTestCase {
       final IndexReader reader = w.getReader();
       w.close();
       final IndexSearcher searcher = newSearcher(reader, true, true, false);
-      Map<TotalHitCountCollector, Integer> expectedCounts = new HashMap<>();
+      Map<DummyTotalHitCountCollector, Integer> expectedCounts = new HashMap<>();
       List<Collector> collectors = new ArrayList<>();
       final int numCollectors = TestUtil.nextInt(random(), 1, 5);
       for (int i = 0; i < numCollectors; ++i) {
         final int terminateAfter = random().nextInt(numDocs + 10);
         final int expectedCount = terminateAfter > numDocs ? numDocs : terminateAfter;
-        TotalHitCountCollector collector = new TotalHitCountCollector();
+        DummyTotalHitCountCollector collector = new DummyTotalHitCountCollector();
         expectedCounts.put(collector, expectedCount);
         collectors.add(new TerminateAfterCollector(collector, terminateAfter));
       }
@@ -124,7 +125,8 @@ public class TestMultiCollector extends LuceneTestCase {
               return null;
             }
           });
-      for (Map.Entry<TotalHitCountCollector, Integer> expectedCount : expectedCounts.entrySet()) {
+      for (Map.Entry<DummyTotalHitCountCollector, Integer> expectedCount :
+          expectedCounts.entrySet()) {
         assertEquals(expectedCount.getValue().intValue(), expectedCount.getKey().getTotalHits());
       }
       reader.close();
@@ -133,8 +135,8 @@ public class TestMultiCollector extends LuceneTestCase {
   }
 
   public void testSetScorerAfterCollectionTerminated() throws IOException {
-    Collector collector1 = new TotalHitCountCollector();
-    Collector collector2 = new TotalHitCountCollector();
+    Collector collector1 = new DummyTotalHitCountCollector();
+    Collector collector2 = new DummyTotalHitCountCollector();
 
     AtomicBoolean setScorerCalled1 = new AtomicBoolean();
     collector1 = new SetScorerCollector(collector1, setScorerCalled1);
@@ -224,7 +226,7 @@ public class TestMultiCollector extends LuceneTestCase {
             scorer.setMinCompetitiveScore(minScore);
           }
         };
-    Collector multiCollector = MultiCollector.wrap(collector, new TotalHitCountCollector());
+    Collector multiCollector = MultiCollector.wrap(collector, new DummyTotalHitCountCollector());
     LeafCollector leafCollector = multiCollector.getLeafCollector(reader.leaves().get(0));
     leafCollector.setScorer(scorer);
     leafCollector.collect(0); // no exception
@@ -283,7 +285,7 @@ public class TestMultiCollector extends LuceneTestCase {
       List<Collector> cols = new ArrayList<>();
       cols.add(collector);
       for (int col = 0; col < numCol; col++) {
-        cols.add(new TerminateAfterCollector(new TotalHitCountCollector(), 0));
+        cols.add(new TerminateAfterCollector(new DummyTotalHitCountCollector(), 0));
       }
       Collections.shuffle(cols, random());
       Collector multiCollector = MultiCollector.wrap(cols);
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSearchWithThreads.java b/lucene/core/src/test/org/apache/lucene/search/TestSearchWithThreads.java
index 17bcce0ee15..d4bd95fca7a 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestSearchWithThreads.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestSearchWithThreads.java
@@ -24,6 +24,7 @@ import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
 import org.apache.lucene.tests.util.LuceneTestCase;
 
 public class TestSearchWithThreads extends LuceneTestCase {
@@ -57,7 +58,7 @@ public class TestSearchWithThreads extends LuceneTestCase {
 
     final AtomicBoolean failed = new AtomicBoolean();
     final AtomicLong netSearch = new AtomicLong();
-    TotalHitCountCollectorManager collectorManager = new TotalHitCountCollectorManager();
+    CollectorManager<?, Integer> collectorManager = DummyTotalHitCountCollector.createManager();
     Thread[] threads = new Thread[numThreads];
     for (int threadID = 0; threadID < numThreads; threadID++) {
       threads[threadID] =
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java
index 5ad4392355b..9a6f39de726 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java
@@ -34,6 +34,7 @@ import org.apache.lucene.index.Terms;
 import org.apache.lucene.index.TermsEnum;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
 import org.apache.lucene.tests.search.QueryUtils;
 import org.apache.lucene.tests.util.LuceneTestCase;
 import org.apache.lucene.tests.util.TestUtil;
@@ -91,14 +92,13 @@ public class TestTermQuery extends LuceneTestCase {
     IndexSearcher searcher = new IndexSearcher(reader);
     // use a collector rather than searcher.count() which would just read the
     // doc freq instead of creating a scorer
-    TotalHitCountCollectorManager collectorManager = new TotalHitCountCollectorManager();
-    int totalHits = searcher.search(query, collectorManager);
+    int totalHits = searcher.search(query, DummyTotalHitCountCollector.createManager());
     assertEquals(1, totalHits);
     TermQuery queryWithContext =
         new TermQuery(
             new Term("foo", "bar"),
             TermStates.build(reader.getContext(), new Term("foo", "bar"), true));
-    totalHits = searcher.search(queryWithContext, collectorManager);
+    totalHits = searcher.search(queryWithContext, DummyTotalHitCountCollector.createManager());
     assertEquals(1, totalHits);
 
     IOUtils.close(reader, w, dir);
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTotalHitCountCollector.java b/lucene/core/src/test/org/apache/lucene/search/TestTotalHitCountCollector.java
index 49049ebd378..eb2afb58e34 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestTotalHitCountCollector.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestTotalHitCountCollector.java
@@ -20,6 +20,8 @@ import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
 import org.apache.lucene.document.StringField;
 import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.BooleanClause.Occur;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
 import org.apache.lucene.tests.util.LuceneTestCase;
@@ -42,6 +44,15 @@ public class TestTotalHitCountCollector extends LuceneTestCase {
     TotalHitCountCollectorManager collectorManager = new TotalHitCountCollectorManager();
     int totalHits = searcher.search(new MatchAllDocsQuery(), collectorManager);
     assertEquals(5, totalHits);
+
+    Query query =
+        new BooleanQuery.Builder()
+            .add(new TermQuery(new Term("string", "a1")), Occur.SHOULD)
+            .add(new TermQuery(new Term("string", "b3")), Occur.SHOULD)
+            .build();
+    totalHits = searcher.search(query, collectorManager);
+    assertEquals(2, totalHits);
+
     reader.close();
     indexStore.close();
   }
diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/ProfilerCollector.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/ProfilerCollector.java
index b38f7d626f1..94f56b93189 100644
--- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/ProfilerCollector.java
+++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/ProfilerCollector.java
@@ -24,6 +24,7 @@ import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.search.Collector;
 import org.apache.lucene.search.LeafCollector;
 import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Weight;
 
 /**
  * This class wraps a Collector and times the execution of: - setScorer() - collect() -
@@ -83,6 +84,11 @@ public class ProfilerCollector implements Collector {
     return collector.getLeafCollector(context);
   }
 
+  @Override
+  public void setWeight(Weight weight) {
+    collector.setWeight(weight);
+  }
+
   @Override
   public ScoreMode scoreMode() {
     return collector.scoreMode();
diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestIndexSortSortedNumericDocValuesRangeQuery.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestIndexSortSortedNumericDocValuesRangeQuery.java
index c10110af812..4323a005c17 100644
--- a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestIndexSortSortedNumericDocValuesRangeQuery.java
+++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestIndexSortSortedNumericDocValuesRangeQuery.java
@@ -40,11 +40,11 @@ import org.apache.lucene.search.Sort;
 import org.apache.lucene.search.SortField;
 import org.apache.lucene.search.SortedNumericSortField;
 import org.apache.lucene.search.TopDocs;
-import org.apache.lucene.search.TotalHitCountCollectorManager;
 import org.apache.lucene.search.Weight;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.analysis.MockAnalyzer;
 import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
 import org.apache.lucene.tests.search.QueryUtils;
 import org.apache.lucene.tests.util.LuceneTestCase;
 import org.apache.lucene.tests.util.TestUtil;
@@ -221,7 +221,8 @@ public class TestIndexSortSortedNumericDocValuesRangeQuery extends LuceneTestCas
   private static void assertNumberOfHits(IndexSearcher searcher, Query query, int numberOfHits)
       throws IOException {
     assertEquals(
-        numberOfHits, searcher.search(query, new TotalHitCountCollectorManager()).intValue());
+        numberOfHits,
+        searcher.search(query, DummyTotalHitCountCollector.createManager()).intValue());
     assertEquals(numberOfHits, searcher.count(query));
   }
 
diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestMultiRangeQueries.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestMultiRangeQueries.java
index fa420b608a0..03b92418199 100644
--- a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestMultiRangeQueries.java
+++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestMultiRangeQueries.java
@@ -37,9 +37,9 @@ import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.Sort;
 import org.apache.lucene.search.TopDocs;
-import org.apache.lucene.search.TotalHitCountCollectorManager;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.search.DummyTotalHitCountCollector;
 import org.apache.lucene.tests.search.QueryUtils;
 import org.apache.lucene.tests.util.LuceneTestCase;
 import org.apache.lucene.tests.util.TestUtil;
@@ -808,8 +808,8 @@ public class TestMultiRangeQueries extends LuceneTestCase {
 
       MultiRangeQuery multiRangeQuery = (MultiRangeQuery) builder1.build().rewrite(reader);
       BooleanQuery booleanQuery = builder2.build();
-      int count = searcher.search(multiRangeQuery, new TotalHitCountCollectorManager());
-      int booleanCount = searcher.search(booleanQuery, new TotalHitCountCollectorManager());
+      int count = searcher.search(multiRangeQuery, DummyTotalHitCountCollector.createManager());
+      int booleanCount = searcher.search(booleanQuery, DummyTotalHitCountCollector.createManager());
       assertEquals(booleanCount, count);
     }
     IOUtils.close(reader, w, dir);
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingCollector.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingCollector.java
index 7ffa9350e2b..cf2c2732614 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingCollector.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingCollector.java
@@ -22,10 +22,12 @@ import org.apache.lucene.search.Collector;
 import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.FilterCollector;
 import org.apache.lucene.search.LeafCollector;
+import org.apache.lucene.search.Weight;
 
 /** A collector that asserts that it is used correctly. */
 class AssertingCollector extends FilterCollector {
 
+  private boolean weightSet = false;
   private int maxDoc = -1;
   private int previousLeafMaxDoc = 0;
 
@@ -43,6 +45,7 @@ class AssertingCollector extends FilterCollector {
 
   @Override
   public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
+    assert weightSet : "Set the weight first";
     assert context.docBase >= previousLeafMaxDoc;
     previousLeafMaxDoc = context.docBase + context.reader().maxDoc();
 
@@ -65,4 +68,12 @@ class AssertingCollector extends FilterCollector {
       }
     };
   }
+
+  @Override
+  public void setWeight(Weight weight) {
+    assert weightSet == false : "Weight set twice";
+    weightSet = true;
+    assert weight != null;
+    in.setWeight(weight);
+  }
 }
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/DummyTotalHitCountCollector.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/DummyTotalHitCountCollector.java
new file mode 100644
index 00000000000..fcb53b96f0a
--- /dev/null
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/DummyTotalHitCountCollector.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.tests.search;
+
+import java.io.IOException;
+import java.util.Collection;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.CollectorManager;
+import org.apache.lucene.search.LeafCollector;
+import org.apache.lucene.search.Scorable;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.TotalHitCountCollector;
+import org.apache.lucene.search.Weight;
+
+/**
+ * A dummy version of {@link TotalHitCountCollector} that doesn't shortcut using {@link
+ * Weight#count}.
+ */
+public class DummyTotalHitCountCollector implements Collector {
+  private int totalHits;
+
+  /** Creates a collector whose hit count starts at zero. */
+  public DummyTotalHitCountCollector() {}
+
+  /** Get the number of hits. */
+  public int getTotalHits() {
+    return totalHits;
+  }
+
+  @Override
+  public ScoreMode scoreMode() {
+    return ScoreMode.COMPLETE_NO_SCORES;
+  }
+
+  @Override
+  public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
+    return new LeafCollector() {
+
+      @Override
+      public void setScorer(Scorable scorer) throws IOException {}
+
+      @Override
+      public void collect(int doc) throws IOException {
+        totalHits++;
+      }
+    };
+  }
+
+  /** Create a collector manager. */
+  public static CollectorManager<DummyTotalHitCountCollector, Integer> createManager() {
+    return new CollectorManager<DummyTotalHitCountCollector, Integer>() {
+
+      @Override
+      public DummyTotalHitCountCollector newCollector() throws IOException {
+        return new DummyTotalHitCountCollector();
+      }
+
+      @Override
+      public Integer reduce(Collection<DummyTotalHitCountCollector> collectors) throws IOException {
+        int sum = 0;
+        for (DummyTotalHitCountCollector coll : collectors) {
+          sum += coll.totalHits;
+        }
+        return sum;
+      }
+    };
+  }
+}