You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by jp...@apache.org on 2023/06/30 13:37:38 UTC

[lucene] 04/04: Add a post-collection hook to LeafCollector. (#12380)

This is an automated email from the ASF dual-hosted git repository.

jpountz pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git

commit de07bdd88dd124da1da8709e337f2b5839f6207d
Author: Adrien Grand <jp...@gmail.com>
AuthorDate: Fri Jun 30 15:19:35 2023 +0200

    Add a post-collection hook to LeafCollector. (#12380)
    
    This adds `LeafCollector#finish` as a per-segment post-collection hook. While
    it was already possible to do this sort of thing on top of the collector API
    before, a downside was that the last leaf had to be post-collected in the
    current thread instead of using the executor, which was a missed opportunity
    for making queries concurrent.
---
 lucene/CHANGES.txt                                 |  4 +-
 .../org/apache/lucene/search/CachingCollector.java | 71 +++++++++++-----------
 .../apache/lucene/search/FilterLeafCollector.java  |  5 ++
 .../org/apache/lucene/search/IndexSearcher.java    |  3 +
 .../org/apache/lucene/search/LeafCollector.java    |  9 +++
 .../org/apache/lucene/search/MultiCollector.java   | 10 +++
 .../apache/lucene/search/TestCachingCollector.java |  2 +
 .../apache/lucene/facet/DrillSidewaysScorer.java   | 17 ++++++
 .../org/apache/lucene/facet/FacetsCollector.java   | 19 +++---
 .../search/grouping/BlockGroupingCollector.java    | 13 ++--
 .../search/grouping/GroupFacetCollector.java       | 11 ++--
 .../search/grouping/TermGroupFacetCollector.java   |  8 ---
 .../suggest/document/SuggestIndexSearcher.java     |  5 +-
 .../suggest/document/TopSuggestDocsCollector.java  | 19 +++---
 .../lucene/tests/search/AssertingCollector.java    | 13 +++-
 .../tests/search/AssertingIndexSearcher.java       |  4 +-
 .../tests/search/AssertingLeafCollector.java       |  8 +++
 17 files changed, 141 insertions(+), 80 deletions(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index ef5c9e496a1..b931b32eda2 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -11,7 +11,9 @@ API Changes
 
 New Features
 ---------------------
-(No changes)
+
+* GITHUB#12383: Introduced LeafCollector#finish, a hook that runs after
+  collection has finished running on a leaf. (Adrien Grand)
 
 Improvements
 ---------------------
diff --git a/lucene/core/src/java/org/apache/lucene/search/CachingCollector.java b/lucene/core/src/java/org/apache/lucene/search/CachingCollector.java
index d065efd3406..fd45666d33a 100644
--- a/lucene/core/src/java/org/apache/lucene/search/CachingCollector.java
+++ b/lucene/core/src/java/org/apache/lucene/search/CachingCollector.java
@@ -66,7 +66,6 @@ public abstract class CachingCollector extends FilterCollector {
     List<LeafReaderContext> contexts;
     List<int[]> docs;
     int maxDocsToCache;
-    NoScoreCachingLeafCollector lastCollector;
 
     NoScoreCachingCollector(Collector in, int maxDocsToCache) {
       super(in);
@@ -76,7 +75,7 @@ public abstract class CachingCollector extends FilterCollector {
     }
 
     protected NoScoreCachingLeafCollector wrap(LeafCollector in, int maxDocsToCache) {
-      return new NoScoreCachingLeafCollector(in, maxDocsToCache);
+      return new NoScoreCachingLeafCollector(in, maxDocsToCache, this);
     }
 
     // note: do *not* override needScore to say false. Just because we aren't caching the score
@@ -85,13 +84,12 @@ public abstract class CachingCollector extends FilterCollector {
 
     @Override
     public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
-      postCollection();
       final LeafCollector in = this.in.getLeafCollector(context);
-      if (contexts != null) {
-        contexts.add(context);
-      }
       if (maxDocsToCache >= 0) {
-        return lastCollector = wrap(in, maxDocsToCache);
+        if (contexts != null) {
+          contexts.add(context);
+        }
+        return wrap(in, maxDocsToCache);
       } else {
         return in;
       }
@@ -103,33 +101,16 @@ public abstract class CachingCollector extends FilterCollector {
       this.docs = null;
     }
 
-    protected void postCollect(NoScoreCachingLeafCollector collector) {
-      final int[] docs = collector.cachedDocs();
-      maxDocsToCache -= docs.length;
-      this.docs.add(docs);
-    }
-
-    private void postCollection() {
-      if (lastCollector != null) {
-        if (!lastCollector.hasCache()) {
-          invalidate();
-        } else {
-          postCollect(lastCollector);
-        }
-        lastCollector = null;
-      }
-    }
-
     protected void collect(LeafCollector collector, int i) throws IOException {
       final int[] docs = this.docs.get(i);
       for (int doc : docs) {
         collector.collect(doc);
       }
+      collector.finish();
     }
 
     @Override
     public void replay(Collector other) throws IOException {
-      postCollection();
       if (!isCached()) {
         throw new IllegalStateException(
             "cannot replay: cache was cleared because too much RAM was required");
@@ -154,14 +135,7 @@ public abstract class CachingCollector extends FilterCollector {
 
     @Override
     protected NoScoreCachingLeafCollector wrap(LeafCollector in, int maxDocsToCache) {
-      return new ScoreCachingLeafCollector(in, maxDocsToCache);
-    }
-
-    @Override
-    protected void postCollect(NoScoreCachingLeafCollector collector) {
-      final ScoreCachingLeafCollector coll = (ScoreCachingLeafCollector) collector;
-      super.postCollect(coll);
-      scores.add(coll.cachedScores());
+      return new ScoreCachingLeafCollector(in, maxDocsToCache, this);
     }
 
     /**
@@ -191,12 +165,15 @@ public abstract class CachingCollector extends FilterCollector {
   private class NoScoreCachingLeafCollector extends FilterLeafCollector {
 
     final int maxDocsToCache;
+    final NoScoreCachingCollector collector;
     int[] docs;
     int docCount;
 
-    NoScoreCachingLeafCollector(LeafCollector in, int maxDocsToCache) {
+    NoScoreCachingLeafCollector(
+        LeafCollector in, int maxDocsToCache, NoScoreCachingCollector collector) {
       super(in);
       this.maxDocsToCache = maxDocsToCache;
+      this.collector = collector;
       docs = new int[Math.min(maxDocsToCache, INITIAL_ARRAY_SIZE)];
       docCount = 0;
     }
@@ -235,6 +212,21 @@ public abstract class CachingCollector extends FilterCollector {
       super.collect(doc);
     }
 
+    protected void postCollect() {
+      final int[] docs = cachedDocs();
+      collector.maxDocsToCache -= docs.length;
+      collector.docs.add(docs);
+    }
+
+    @Override
+    public void finish() {
+      if (!hasCache()) {
+        collector.invalidate();
+      } else {
+        postCollect();
+      }
+    }
+
     boolean hasCache() {
       return docs != null;
     }
@@ -249,8 +241,9 @@ public abstract class CachingCollector extends FilterCollector {
     Scorable scorer;
     float[] scores;
 
-    ScoreCachingLeafCollector(LeafCollector in, int maxDocsToCache) {
-      super(in, maxDocsToCache);
+    ScoreCachingLeafCollector(
+        LeafCollector in, int maxDocsToCache, ScoreCachingCollector collector) {
+      super(in, maxDocsToCache, collector);
       scores = new float[docs.length];
     }
 
@@ -281,6 +274,12 @@ public abstract class CachingCollector extends FilterCollector {
     float[] cachedScores() {
       return docs == null ? null : ArrayUtil.copyOfSubArray(scores, 0, docCount);
     }
+
+    @Override
+    protected void postCollect() {
+      super.postCollect();
+      ((ScoreCachingCollector) collector).scores.add(cachedScores());
+    }
   }
 
   /**
diff --git a/lucene/core/src/java/org/apache/lucene/search/FilterLeafCollector.java b/lucene/core/src/java/org/apache/lucene/search/FilterLeafCollector.java
index 24733668ff6..d9bc671fd42 100644
--- a/lucene/core/src/java/org/apache/lucene/search/FilterLeafCollector.java
+++ b/lucene/core/src/java/org/apache/lucene/search/FilterLeafCollector.java
@@ -42,6 +42,11 @@ public abstract class FilterLeafCollector implements LeafCollector {
     in.collect(doc);
   }
 
+  @Override
+  public void finish() throws IOException {
+    in.finish();
+  }
+
   @Override
   public String toString() {
     String name = getClass().getSimpleName();
diff --git a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java
index 9c0a29c052f..b76bf3161d0 100644
--- a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java
+++ b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java
@@ -782,6 +782,9 @@ public class IndexSearcher {
           partialResult = true;
         }
       }
+      // Note: this is called if collection ran successfully, including the above special cases of
+      // CollectionTerminatedException and TimeExceededException, but no other exception.
+      leafCollector.finish();
     }
   }
 
diff --git a/lucene/core/src/java/org/apache/lucene/search/LeafCollector.java b/lucene/core/src/java/org/apache/lucene/search/LeafCollector.java
index a42d531c2b2..334afc798ca 100644
--- a/lucene/core/src/java/org/apache/lucene/search/LeafCollector.java
+++ b/lucene/core/src/java/org/apache/lucene/search/LeafCollector.java
@@ -95,4 +95,13 @@ public interface LeafCollector {
   default DocIdSetIterator competitiveIterator() throws IOException {
     return null;
   }
+
+  /**
+   * Hook that gets called once the leaf that is associated with this collector has finished
+   * collecting successfully, including when a {@link CollectionTerminatedException} is thrown. This
+   * is typically useful to compile data that has been collected on this leaf, e.g. to convert facet
+   * counts on leaf ordinals to facet counts on global ordinals. The default implementation does
+   * nothing.
+   */
+  default void finish() throws IOException {}
 }
diff --git a/lucene/core/src/java/org/apache/lucene/search/MultiCollector.java b/lucene/core/src/java/org/apache/lucene/search/MultiCollector.java
index 7f8e8121734..ff6a6a97ba5 100644
--- a/lucene/core/src/java/org/apache/lucene/search/MultiCollector.java
+++ b/lucene/core/src/java/org/apache/lucene/search/MultiCollector.java
@@ -223,6 +223,7 @@ public class MultiCollector implements Collector {
           } catch (
               @SuppressWarnings("unused")
               CollectionTerminatedException e) {
+            collectors[i].finish();
             collectors[i] = null;
             if (allCollectorsTerminated()) {
               throw new CollectionTerminatedException();
@@ -232,6 +233,15 @@ public class MultiCollector implements Collector {
       }
     }
 
+    @Override
+    public void finish() throws IOException {
+      for (LeafCollector collector : collectors) {
+        if (collector != null) {
+          collector.finish();
+        }
+      }
+    }
+
     private boolean allCollectorsTerminated() {
       for (int i = 0; i < collectors.length; i++) {
         if (collectors[i] != null) {
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestCachingCollector.java b/lucene/core/src/test/org/apache/lucene/search/TestCachingCollector.java
index 4b07e06dd0a..425c33a34e7 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestCachingCollector.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestCachingCollector.java
@@ -57,6 +57,7 @@ public class TestCachingCollector extends LuceneTestCase {
       for (int i = 0; i < 1000; i++) {
         acc.collect(i);
       }
+      acc.finish();
 
       // now replay them
       cc.replay(
@@ -127,6 +128,7 @@ public class TestCachingCollector extends LuceneTestCase {
       acc.collect(0);
 
       assertTrue(cc.isCached());
+      acc.finish();
       cc.replay(new NoOpCollector());
     }
   }
diff --git a/lucene/facet/src/java/org/apache/lucene/facet/DrillSidewaysScorer.java b/lucene/facet/src/java/org/apache/lucene/facet/DrillSidewaysScorer.java
index c86432b30f7..ad82594f41d 100644
--- a/lucene/facet/src/java/org/apache/lucene/facet/DrillSidewaysScorer.java
+++ b/lucene/facet/src/java/org/apache/lucene/facet/DrillSidewaysScorer.java
@@ -18,6 +18,7 @@ package org.apache.lucene.facet;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Comparator;
@@ -198,6 +199,7 @@ class DrillSidewaysScorer extends BulkScorer {
 
       docID = baseApproximation.nextDoc();
     }
+    finish(collector, Collections.singleton(dim));
   }
 
   /**
@@ -334,6 +336,8 @@ class DrillSidewaysScorer extends BulkScorer {
 
       docID = baseApproximation.nextDoc();
     }
+
+    finish(collector, sidewaysDims);
   }
 
   private static int advanceIfBehind(int docID, DocIdSetIterator iterator) throws IOException {
@@ -552,6 +556,7 @@ class DrillSidewaysScorer extends BulkScorer {
 
       nextChunkStart += CHUNK;
     }
+    finish(collector, Arrays.asList(dims));
   }
 
   private void doUnionScoring(Bits acceptDocs, LeafCollector collector, DocsAndCost[] dims)
@@ -706,6 +711,8 @@ class DrillSidewaysScorer extends BulkScorer {
 
       nextChunkStart += CHUNK;
     }
+
+    finish(collector, Arrays.asList(dims));
   }
 
   private void collectHit(LeafCollector collector, DocsAndCost[] dims) throws IOException {
@@ -757,6 +764,16 @@ class DrillSidewaysScorer extends BulkScorer {
     sidewaysCollector.collect(collectDocID);
   }
 
+  private void finish(LeafCollector collector, Collection<DocsAndCost> dims) throws IOException {
+    collector.finish();
+    if (drillDownLeafCollector != null) {
+      drillDownLeafCollector.finish();
+    }
+    for (DocsAndCost dim : dims) {
+      dim.sidewaysLeafCollector.finish();
+    }
+  }
+
   private void setScorer(LeafCollector mainCollector, Scorable scorer) throws IOException {
     mainCollector.setScorer(scorer);
     if (drillDownLeafCollector != null) {
diff --git a/lucene/facet/src/java/org/apache/lucene/facet/FacetsCollector.java b/lucene/facet/src/java/org/apache/lucene/facet/FacetsCollector.java
index cf49aef39ee..2bce78e22b3 100644
--- a/lucene/facet/src/java/org/apache/lucene/facet/FacetsCollector.java
+++ b/lucene/facet/src/java/org/apache/lucene/facet/FacetsCollector.java
@@ -103,13 +103,6 @@ public class FacetsCollector extends SimpleCollector {
 
   /** Returns the documents matched by the query, one {@link MatchingDocs} per visited segment. */
   public List<MatchingDocs> getMatchingDocs() {
-    if (docsBuilder != null) {
-      matchingDocs.add(new MatchingDocs(this.context, docsBuilder.build(), totalHits, scores));
-      docsBuilder = null;
-      scores = null;
-      context = null;
-    }
-
     return matchingDocs;
   }
 
@@ -139,9 +132,7 @@ public class FacetsCollector extends SimpleCollector {
 
   @Override
   protected void doSetNextReader(LeafReaderContext context) throws IOException {
-    if (docsBuilder != null) {
-      matchingDocs.add(new MatchingDocs(this.context, docsBuilder.build(), totalHits, scores));
-    }
+    assert docsBuilder == null;
     docsBuilder = new DocIdSetBuilder(context.reader().maxDoc());
     totalHits = 0;
     if (keepScores) {
@@ -150,6 +141,14 @@ public class FacetsCollector extends SimpleCollector {
     this.context = context;
   }
 
+  @Override
+  public void finish() throws IOException {
+    matchingDocs.add(new MatchingDocs(this.context, docsBuilder.build(), totalHits, scores));
+    docsBuilder = null;
+    scores = null;
+    context = null;
+  }
+
   /** Utility method, to search and also collect all hits into the provided {@link Collector}. */
   public static TopDocs search(IndexSearcher searcher, Query q, int n, Collector fc)
       throws IOException {
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java
index 9ead686831e..26c3c915dd3 100644
--- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/BlockGroupingCollector.java
@@ -270,9 +270,6 @@ public class BlockGroupingCollector extends SimpleCollector {
     // if (queueFull) {
     // System.out.println("getTopGroups groupOffset=" + groupOffset + " topNGroups=" + topNGroups);
     // }
-    if (subDocUpto != 0) {
-      processGroup();
-    }
     if (groupOffset >= groupQueue.size()) {
       return null;
     }
@@ -472,9 +469,6 @@ public class BlockGroupingCollector extends SimpleCollector {
 
   @Override
   protected void doSetNextReader(LeafReaderContext readerContext) throws IOException {
-    if (subDocUpto != 0) {
-      processGroup();
-    }
     subDocUpto = 0;
     docBase = readerContext.docBase;
     // System.out.println("setNextReader base=" + docBase + " r=" + readerContext.reader);
@@ -492,6 +486,13 @@ public class BlockGroupingCollector extends SimpleCollector {
     }
   }
 
+  @Override
+  public void finish() throws IOException {
+    if (subDocUpto != 0) {
+      processGroup();
+    }
+  }
+
   @Override
   public ScoreMode scoreMode() {
     return needsScores ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupFacetCollector.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupFacetCollector.java
index 74a957d809d..4e56a12c902 100644
--- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupFacetCollector.java
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupFacetCollector.java
@@ -67,11 +67,6 @@ public abstract class GroupFacetCollector extends SimpleCollector {
    */
   public GroupedFacetResult mergeSegmentResults(int size, int minCount, boolean orderByCount)
       throws IOException {
-    if (segmentFacetCounts != null) {
-      segmentResults.add(createSegmentResult());
-      segmentFacetCounts = null; // reset
-    }
-
     int totalCount = 0;
     int missingCount = 0;
     SegmentResultPriorityQueue segments = new SegmentResultPriorityQueue(segmentResults.size());
@@ -109,6 +104,12 @@ public abstract class GroupFacetCollector extends SimpleCollector {
     return facetResult;
   }
 
+  @Override
+  public void finish() throws IOException {
+    segmentResults.add(createSegmentResult());
+    segmentFacetCounts = null;
+  }
+
   protected abstract SegmentResult createSegmentResult() throws IOException;
 
   @Override
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/TermGroupFacetCollector.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/TermGroupFacetCollector.java
index c1b49758b13..e49e517faa8 100644
--- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/TermGroupFacetCollector.java
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/TermGroupFacetCollector.java
@@ -141,10 +141,6 @@ public abstract class TermGroupFacetCollector extends GroupFacetCollector {
 
     @Override
     protected void doSetNextReader(LeafReaderContext context) throws IOException {
-      if (segmentFacetCounts != null) {
-        segmentResults.add(createSegmentResult());
-      }
-
       groupFieldTermsIndex = DocValues.getSorted(context.reader(), groupField);
       facetFieldTermsIndex = DocValues.getSorted(context.reader(), facetField);
 
@@ -321,10 +317,6 @@ public abstract class TermGroupFacetCollector extends GroupFacetCollector {
 
     @Override
     protected void doSetNextReader(LeafReaderContext context) throws IOException {
-      if (segmentFacetCounts != null) {
-        segmentResults.add(createSegmentResult());
-      }
-
       groupFieldTermsIndex = DocValues.getSorted(context.reader(), groupField);
       facetFieldDocTermOrds = DocValues.getSortedSet(context.reader(), facetField);
       facetFieldNumTerms = (int) facetFieldDocTermOrds.getValueCount();
diff --git a/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/SuggestIndexSearcher.java b/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/SuggestIndexSearcher.java
index e46e73bb1aa..0c88359029b 100644
--- a/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/SuggestIndexSearcher.java
+++ b/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/SuggestIndexSearcher.java
@@ -22,6 +22,7 @@ import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.search.BulkScorer;
 import org.apache.lucene.search.CollectionTerminatedException;
 import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.LeafCollector;
 import org.apache.lucene.search.Weight;
 
 /**
@@ -67,14 +68,16 @@ public class SuggestIndexSearcher extends IndexSearcher {
     for (LeafReaderContext context : getIndexReader().leaves()) {
       BulkScorer scorer = weight.bulkScorer(context);
       if (scorer != null) {
+        LeafCollector leafCollector = collector.getLeafCollector(context);
         try {
-          scorer.score(collector.getLeafCollector(context), context.reader().getLiveDocs());
+          scorer.score(leafCollector, context.reader().getLiveDocs());
         } catch (
             @SuppressWarnings("unused")
             CollectionTerminatedException e) {
           // collection was terminated prematurely
           // continue with the following leaf
         }
+        leafCollector.finish();
       }
     }
   }
diff --git a/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/TopSuggestDocsCollector.java b/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/TopSuggestDocsCollector.java
index e8f66c87b3b..3cfc8e2582d 100644
--- a/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/TopSuggestDocsCollector.java
+++ b/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/TopSuggestDocsCollector.java
@@ -100,12 +100,19 @@ public class TopSuggestDocsCollector extends SimpleCollector {
   @Override
   protected void doSetNextReader(LeafReaderContext context) throws IOException {
     docBase = context.docBase;
+  }
+
+  @Override
+  public void finish() throws IOException {
     if (seenSurfaceForms != null) {
-      seenSurfaceForms.clear();
       // NOTE: this also clears the priorityQueue:
       for (SuggestScoreDoc hit : priorityQueue.getResults()) {
         pendingResults.add(hit);
       }
+
+      // Deduplicate all hits: we already dedup'd efficiently within each segment by
+      // truncating the FST top paths search, but across segments there may still be dups:
+      seenSurfaceForms.clear();
     }
   }
 
@@ -136,15 +143,7 @@ public class TopSuggestDocsCollector extends SimpleCollector {
     SuggestScoreDoc[] suggestScoreDocs;
 
     if (seenSurfaceForms != null) {
-      // NOTE: this also clears the priorityQueue:
-      for (SuggestScoreDoc hit : priorityQueue.getResults()) {
-        pendingResults.add(hit);
-      }
-
-      // Deduplicate all hits: we already dedup'd efficiently within each segment by
-      // truncating the FST top paths search, but across segments there may still be dups:
-      seenSurfaceForms.clear();
-
+      assert seenSurfaceForms.isEmpty();
       // TODO: we could use a priority queue here to make cost O(N * log(num)) instead of O(N *
       // log(N)), where N = O(num *
       // numSegments), but typically numSegments is smallish and num is smallish so this won't
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingCollector.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingCollector.java
index cf2c2732614..af0df8cdc3f 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingCollector.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingCollector.java
@@ -30,11 +30,12 @@ class AssertingCollector extends FilterCollector {
   private boolean weightSet = false;
   private int maxDoc = -1;
   private int previousLeafMaxDoc = 0;
+  boolean hasFinishedCollectingPreviousLeaf = true;
 
   /** Wrap the given collector in order to add assertions. */
-  public static Collector wrap(Collector in) {
+  public static AssertingCollector wrap(Collector in) {
     if (in instanceof AssertingCollector) {
-      return in;
+      return (AssertingCollector) in;
     }
     return new AssertingCollector(in);
   }
@@ -49,7 +50,9 @@ class AssertingCollector extends FilterCollector {
     assert context.docBase >= previousLeafMaxDoc;
     previousLeafMaxDoc = context.docBase + context.reader().maxDoc();
 
+    assert hasFinishedCollectingPreviousLeaf;
     final LeafCollector in = super.getLeafCollector(context);
+    hasFinishedCollectingPreviousLeaf = false;
     final int docBase = context.docBase;
     return new AssertingLeafCollector(in, 0, DocIdSetIterator.NO_MORE_DOCS) {
       @Override
@@ -66,6 +69,12 @@ class AssertingCollector extends FilterCollector {
         super.collect(doc);
         maxDoc = docBase + doc;
       }
+
+      @Override
+      public void finish() throws IOException {
+        hasFinishedCollectingPreviousLeaf = true;
+        super.finish();
+      }
     };
   }
 
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingIndexSearcher.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingIndexSearcher.java
index 83abd788309..f5fa29b1494 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingIndexSearcher.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingIndexSearcher.java
@@ -75,7 +75,9 @@ public class AssertingIndexSearcher extends IndexSearcher {
   protected void search(List<LeafReaderContext> leaves, Weight weight, Collector collector)
       throws IOException {
     assert weight instanceof AssertingWeight;
-    super.search(leaves, weight, AssertingCollector.wrap(collector));
+    AssertingCollector assertingCollector = AssertingCollector.wrap(collector);
+    super.search(leaves, weight, assertingCollector);
+    assert assertingCollector.hasFinishedCollectingPreviousLeaf;
   }
 
   @Override
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingLeafCollector.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingLeafCollector.java
index 5c7801e5122..bcf3c4b1098 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingLeafCollector.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingLeafCollector.java
@@ -30,6 +30,7 @@ class AssertingLeafCollector extends FilterLeafCollector {
 
   private Scorable scorer;
   private int lastCollected = -1;
+  private boolean finishCalled;
 
   AssertingLeafCollector(LeafCollector collector, int min, int max) {
     super(collector);
@@ -57,4 +58,11 @@ class AssertingLeafCollector extends FilterLeafCollector {
   public DocIdSetIterator competitiveIterator() throws IOException {
     return in.competitiveIterator();
   }
+
+  @Override
+  public void finish() throws IOException {
+    assert finishCalled == false;
+    finishCalled = true;
+    super.finish();
+  }
 }