Posted to commits@lucene.apache.org by gs...@apache.org on 2023/03/01 13:20:35 UTC

[lucene] branch main updated: Remove custom TermInSetQuery implementation in favor of extending MultiTermQuery (#12156)

This is an automated email from the ASF dual-hosted git repository.

gsmiller pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/main by this push:
     new 3809106602a Remove custom TermInSetQuery implementation in favor of extending MultiTermQuery (#12156)
3809106602a is described below

commit 3809106602a9675f4fd217b1090af4505d4ec2a7
Author: Greg Miller <gs...@gmail.com>
AuthorDate: Wed Mar 1 05:20:28 2023 -0800

    Remove custom TermInSetQuery implementation in favor of extending MultiTermQuery (#12156)
---
 lucene/CHANGES.txt                                 |   4 +
 ...AbstractMultiTermQueryConstantScoreWrapper.java |  92 +++++-
 .../org/apache/lucene/search/BooleanWeight.java    |   3 +-
 .../apache/lucene/search/DisjunctionMaxQuery.java  |   3 +-
 .../org/apache/lucene/search/MultiTermQuery.java   |   8 +
 .../org/apache/lucene/search/TermInSetQuery.java   | 341 +++++----------------
 .../apache/lucene/search/TestTermInSetQuery.java   |  11 +-
 .../org/apache/lucene/search/join/TermsQuery.java  |   5 +
 .../monitor/TestPresearcherMatchCollector.java     |   4 +-
 9 files changed, 194 insertions(+), 277 deletions(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 1bbdbb31d05..dc7ffdcb265 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -125,6 +125,10 @@ Improvements
   for multi-term queries with a FILTER rewrite (PrefixQuery, WildcardQuery, TermRangeQuery). This introduces better
   skipping support for common use-cases. (Adrien Grand, Greg Miller)
 
+* GITHUB#12156: TermInSetQuery now extends MultiTermQuery instead of providing its own custom implementation (which
+  was essentially a clone of MultiTermQuery#CONSTANT_SCORE_REWRITE). It uses the new CONSTANT_SCORE_BLENDED_REWRITE
+  by default, but the rewrite method can be overridden through the constructor. (Greg Miller)
+
 Optimizations
 ---------------------
 
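For illustration, a minimal usage sketch of the constructor-based override described in the CHANGES entry above. The class name, field name, and terms are hypothetical and not part of this commit; DocValuesRewriteMethod is one of the stock rewrite options mentioned in the TermInSetQuery javadoc below.

    import java.util.List;

    import org.apache.lucene.search.DocValuesRewriteMethod;
    import org.apache.lucene.search.Query;
    import org.apache.lucene.search.TermInSetQuery;
    import org.apache.lucene.util.BytesRef;

    class TermInSetQueryExamples {
      // Uses the new default, MultiTermQuery.CONSTANT_SCORE_BLENDED_REWRITE:
      static Query defaultRewrite() {
        return new TermInSetQuery("color", List.of(new BytesRef("red"), new BytesRef("blue")));
      }

      // Overrides the rewrite method through the constructor, here to execute
      // against doc values (assumes "color" is indexed with sorted doc values):
      static Query docValuesRewrite() {
        return new TermInSetQuery(
            new DocValuesRewriteMethod(),
            "color",
            List.of(new BytesRef("red"), new BytesRef("blue")));
      }
    }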
diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractMultiTermQueryConstantScoreWrapper.java b/lucene/core/src/java/org/apache/lucene/search/AbstractMultiTermQueryConstantScoreWrapper.java
index 622264d5059..557746dd1e3 100644
--- a/lucene/core/src/java/org/apache/lucene/search/AbstractMultiTermQueryConstantScoreWrapper.java
+++ b/lucene/core/src/java/org/apache/lucene/search/AbstractMultiTermQueryConstantScoreWrapper.java
@@ -40,7 +40,7 @@ import org.apache.lucene.util.RamUsageEstimator;
 abstract class AbstractMultiTermQueryConstantScoreWrapper<Q extends MultiTermQuery> extends Query
     implements Accountable {
   // mtq that matches 16 terms or fewer will be executed as a regular disjunction
-  private static final int BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD = 16;
+  static final int BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD = 16;
 
   protected final Q query;
 
@@ -153,12 +153,9 @@ abstract class AbstractMultiTermQueryConstantScoreWrapper<Q extends MultiTermQue
         List<TermAndState> collectedTerms)
         throws IOException;
 
-    private WeightOrDocIdSetIterator rewrite(LeafReaderContext context) throws IOException {
-      final Terms terms = context.reader().terms(q.field);
-      if (terms == null) {
-        // field does not exist
-        return null;
-      }
+    private WeightOrDocIdSetIterator rewrite(LeafReaderContext context, Terms terms)
+        throws IOException {
+      assert terms != null;
 
       final int fieldDocCount = terms.getDocCount();
       final TermsEnum termsEnum = q.getTermsEnum(terms);
@@ -216,7 +213,11 @@ abstract class AbstractMultiTermQueryConstantScoreWrapper<Q extends MultiTermQue
 
     @Override
     public BulkScorer bulkScorer(LeafReaderContext context) throws IOException {
-      final WeightOrDocIdSetIterator weightOrIterator = rewrite(context);
+      final Terms terms = context.reader().terms(q.getField());
+      if (terms == null) {
+        return null;
+      }
+      final WeightOrDocIdSetIterator weightOrIterator = rewrite(context, terms);
       if (weightOrIterator == null) {
         return null;
       } else if (weightOrIterator.weight != null) {
@@ -232,14 +233,11 @@ abstract class AbstractMultiTermQueryConstantScoreWrapper<Q extends MultiTermQue
 
     @Override
     public Scorer scorer(LeafReaderContext context) throws IOException {
-      final WeightOrDocIdSetIterator weightOrIterator = rewrite(context);
-      if (weightOrIterator == null) {
+      final ScorerSupplier scorerSupplier = scorerSupplier(context);
+      if (scorerSupplier == null) {
         return null;
-      } else if (weightOrIterator.weight != null) {
-        return weightOrIterator.weight.scorer(context);
-      } else {
-        return scorerForIterator(weightOrIterator.iterator);
       }
+      return scorerSupplier.get(Long.MAX_VALUE);
     }
 
     @Override
@@ -255,6 +253,72 @@ abstract class AbstractMultiTermQueryConstantScoreWrapper<Q extends MultiTermQue
                   context, doc, q, q.field, q.getTermsEnum(terms)));
     }
 
+    @Override
+    public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
+      final Terms terms = context.reader().terms(q.getField());
+      if (terms == null) {
+        return null;
+      }
+
+      final long cost = estimateCost(terms, q.getTermsCount());
+
+      final Weight weight = this;
+      return new ScorerSupplier() {
+        @Override
+        public Scorer get(long leadCost) throws IOException {
+          WeightOrDocIdSetIterator weightOrIterator = rewrite(context, terms);
+          final Scorer scorer;
+          if (weightOrIterator == null) {
+            scorer = null;
+          } else if (weightOrIterator.weight != null) {
+            scorer = weightOrIterator.weight.scorer(context);
+          } else {
+            scorer = scorerForIterator(weightOrIterator.iterator);
+          }
+
+          // It's against the API contract to return a null scorer from a non-null ScorerSupplier.
+          // So if our ScorerSupplier was non-null (i.e., thought there might be hits) but we now
+          // find that there are actually no hits, we need to return an empty Scorer as opposed
+          // to null:
+          return Objects.requireNonNullElseGet(
+              scorer,
+              () -> new ConstantScoreScorer(weight, score(), scoreMode, DocIdSetIterator.empty()));
+        }
+
+        @Override
+        public long cost() {
+          return cost;
+        }
+      };
+    }
+
+    private static long estimateCost(Terms terms, long queryTermsCount) throws IOException {
+      // Estimate the cost. If the MTQ can provide its term count, we can do a better job
+      // estimating.
+      // Cost estimation reasoning is:
+      // 1. If we don't know how many query terms there are, we assume that every term could be
+      //    in the MTQ and estimate the work as the total docs across all terms.
+      // 2. If we know how many query terms there are...
+      //    2a. Assume every query term matches at least one document (queryTermsCount).
+      //    2b. Determine the total number of docs beyond the first one for each term.
+      //        That count provides a ceiling on the number of extra docs that could match beyond
+      //        that first one. (We omit the first since it's already been counted in 2a).
+      // See: LUCENE-10207
+      long cost;
+      if (queryTermsCount == -1) {
+        cost = terms.getSumDocFreq();
+      } else {
+        long potentialExtraCost = terms.getSumDocFreq();
+        final long indexedTermCount = terms.size();
+        if (indexedTermCount != -1) {
+          potentialExtraCost -= indexedTermCount;
+        }
+        cost = queryTermsCount + potentialExtraCost;
+      }
+
+      return cost;
+    }
+
     @Override
     public boolean isCacheable(LeafReaderContext ctx) {
       return true;
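To make the estimateCost() reasoning above concrete with hypothetical numbers: for a query with 5 terms against a field whose terms dictionary reports sumDocFreq = 1,000 over 200 indexed terms, the estimate is 5 + (1,000 - 200) = 805. If the query's term count were unknown (getTermsCount() == -1), the estimate would fall back to the full sumDocFreq of 1,000.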
diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanWeight.java b/lucene/core/src/java/org/apache/lucene/search/BooleanWeight.java
index a9aee1185d1..ce417e3cb79 100644
--- a/lucene/core/src/java/org/apache/lucene/search/BooleanWeight.java
+++ b/lucene/core/src/java/org/apache/lucene/search/BooleanWeight.java
@@ -513,7 +513,8 @@ final class BooleanWeight extends Weight {
 
   @Override
   public boolean isCacheable(LeafReaderContext ctx) {
-    if (query.clauses().size() > TermInSetQuery.BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD) {
+    if (query.clauses().size()
+        > AbstractMultiTermQueryConstantScoreWrapper.BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD) {
       // Disallow caching large boolean queries to not encourage users
       // to build large boolean queries as a workaround to the fact that
       // we disallow caching large TermInSetQueries.
diff --git a/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java b/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java
index 29772154b16..1ab4f3b5081 100644
--- a/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java
@@ -151,7 +151,8 @@ public final class DisjunctionMaxQuery extends Query implements Iterable<Query>
 
     @Override
     public boolean isCacheable(LeafReaderContext ctx) {
-      if (weights.size() > TermInSetQuery.BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD) {
+      if (weights.size()
+          > AbstractMultiTermQueryConstantScoreWrapper.BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD) {
         // Disallow caching large dismax queries to not encourage users
         // to build large dismax queries as a workaround to the fact that
         // we disallow caching large TermInSetQueries.
diff --git a/lucene/core/src/java/org/apache/lucene/search/MultiTermQuery.java b/lucene/core/src/java/org/apache/lucene/search/MultiTermQuery.java
index 28e2564e496..72c1ea1ca91 100644
--- a/lucene/core/src/java/org/apache/lucene/search/MultiTermQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/MultiTermQuery.java
@@ -295,6 +295,14 @@ public abstract class MultiTermQuery extends Query {
     return getTermsEnum(terms, new AttributeSource());
   }
 
+  /**
+   * Returns the number of unique terms contained in this query, if known up-front. If not known,
+   * {@code -1} is returned.
+   */
+  public long getTermsCount() throws IOException {
+    return -1;
+  }
+
   /**
    * To rewrite to a simpler form, instead return a simpler enum from {@link #getTermsEnum(Terms,
    * AttributeSource)}. For example, to rewrite to a single term, return a {@link SingleTermsEnum}
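For illustration, a hypothetical MultiTermQuery subclass showing how the new getTermsCount() hook might be used. Nothing here exists in the commit; SingleTermsEnum is the helper mentioned in the javadoc above, and the one-term query is deliberately contrived.

    import java.io.IOException;

    import org.apache.lucene.index.SingleTermsEnum;
    import org.apache.lucene.index.Terms;
    import org.apache.lucene.index.TermsEnum;
    import org.apache.lucene.search.MultiTermQuery;
    import org.apache.lucene.search.QueryVisitor;
    import org.apache.lucene.util.AttributeSource;
    import org.apache.lucene.util.BytesRef;

    // Matches a single fixed term; contrived, but enough to show the hook.
    // (Production code should also override equals()/hashCode() to include 'term'.)
    class OneTermQuery extends MultiTermQuery {
      private final BytesRef term;

      OneTermQuery(String field, BytesRef term) {
        super(field, CONSTANT_SCORE_BLENDED_REWRITE);
        this.term = term;
      }

      @Override
      protected TermsEnum getTermsEnum(Terms terms, AttributeSource atts) throws IOException {
        return new SingleTermsEnum(terms.iterator(), term);
      }

      @Override
      public long getTermsCount() {
        return 1; // known up-front, enabling the tighter cost estimate
      }

      @Override
      public String toString(String field) {
        return "OneTermQuery(" + getField() + ":" + term.utf8ToString() + ")";
      }

      @Override
      public void visit(QueryVisitor visitor) {
        visitor.visitLeaf(this);
      }
    }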
diff --git a/lucene/core/src/java/org/apache/lucene/search/TermInSetQuery.java b/lucene/core/src/java/org/apache/lucene/search/TermInSetQuery.java
index ef254c894ec..9d482369975 100644
--- a/lucene/core/src/java/org/apache/lucene/search/TermInSetQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/TermInSetQuery.java
@@ -22,24 +22,18 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
-import java.util.Objects;
 import java.util.SortedSet;
-import org.apache.lucene.index.LeafReader;
-import org.apache.lucene.index.LeafReaderContext;
-import org.apache.lucene.index.PostingsEnum;
+import org.apache.lucene.index.FilteredTermsEnum;
 import org.apache.lucene.index.PrefixCodedTerms;
 import org.apache.lucene.index.PrefixCodedTerms.TermIterator;
 import org.apache.lucene.index.Term;
-import org.apache.lucene.index.TermState;
-import org.apache.lucene.index.TermStates;
 import org.apache.lucene.index.Terms;
 import org.apache.lucene.index.TermsEnum;
-import org.apache.lucene.search.BooleanClause.Occur;
 import org.apache.lucene.util.Accountable;
 import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.AttributeSource;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.BytesRefBuilder;
-import org.apache.lucene.util.DocIdSetBuilder;
 import org.apache.lucene.util.RamUsageEstimator;
 import org.apache.lucene.util.automaton.Automata;
 import org.apache.lucene.util.automaton.Automaton;
@@ -48,8 +42,8 @@ import org.apache.lucene.util.automaton.CompiledAutomaton;
 import org.apache.lucene.util.automaton.Operations;
 
 /**
- * Specialization for a disjunction over many terms that behaves like a {@link ConstantScoreQuery}
- * over a {@link BooleanQuery} containing only {@link
+ * Specialization for a disjunction over many terms that, by default, behaves like a {@link
+ * ConstantScoreQuery} over a {@link BooleanQuery} containing only {@link
  * org.apache.lucene.search.BooleanClause.Occur#SHOULD} clauses.
  *
  * <p>For instance in the following example, both {@code q1} and {@code q2} would yield the same
@@ -64,30 +58,62 @@ import org.apache.lucene.util.automaton.Operations;
  * Query q2 = new ConstantScoreQuery(bq);
  * </pre>
  *
- * <p>When there are few terms, this query executes like a regular disjunction. However, when there
- * are many terms, instead of merging iterators on the fly, it will populate a bit set with matching
- * docs and return a {@link Scorer} over this bit set.
+ * <p>Unless a custom {@link MultiTermQuery.RewriteMethod} is provided, this query executes like a
+ * regular disjunction when there are few terms. However, when there are many terms, instead of
+ * merging iterators on the fly, it will populate a bit set with matching docs for the least-costly
+ * terms and maintain a size-limited set of more costly iterators that are merged on the fly. For
+ * more details, see {@link MultiTermQuery#CONSTANT_SCORE_BLENDED_REWRITE}.
+ *
+ * <p>Users may also provide a custom {@link MultiTermQuery.RewriteMethod} to define different
+ * execution behavior, such as relying on doc values (see: {@link DocValuesRewriteMethod}), or if
+ * scores are required (see: {@link MultiTermQuery#SCORING_BOOLEAN_REWRITE}). See {@link
+ * MultiTermQuery} documentation for more rewrite options.
  *
  * <p>NOTE: This query produces scores that are equal to its boost
  */
-public class TermInSetQuery extends Query implements Accountable {
+public class TermInSetQuery extends MultiTermQuery implements Accountable {
 
   private static final long BASE_RAM_BYTES_USED =
       RamUsageEstimator.shallowSizeOfInstance(TermInSetQuery.class);
-  // Same threshold as MultiTermQueryConstantScoreWrapper
-  static final int BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD = 16;
 
   private final String field;
   private final PrefixCodedTerms termData;
   private final int termDataHashCode; // cached hashcode of termData
 
-  /** Creates a new {@link TermInSetQuery} from the given collection of terms. */
   public TermInSetQuery(String field, Collection<BytesRef> terms) {
+    this(field, packTerms(field, terms));
+  }
+
+  public TermInSetQuery(String field, BytesRef... terms) {
+    this(field, packTerms(field, Arrays.asList(terms)));
+  }
+
+  /**
+   * Creates a new {@link TermInSetQuery} from the given collection of terms, using the provided
+   * {@link RewriteMethod}.
+   */
+  public TermInSetQuery(RewriteMethod rewriteMethod, String field, Collection<BytesRef> terms) {
+    super(field, rewriteMethod);
+    this.field = field;
+    this.termData = packTerms(field, terms);
+    termDataHashCode = termData.hashCode();
+  }
+
+  /**
+   * Creates a new {@link TermInSetQuery} from the given array of terms, using the provided {@link
+   * RewriteMethod}.
+   */
+  public TermInSetQuery(RewriteMethod rewriteMethod, String field, BytesRef... terms) {
+    this(rewriteMethod, field, Arrays.asList(terms));
+  }
+
+  private TermInSetQuery(String field, PrefixCodedTerms termData) {
+    super(field, MultiTermQuery.CONSTANT_SCORE_BLENDED_REWRITE);
+    this.field = field;
+    this.termData = termData;
+    termDataHashCode = termData.hashCode();
+  }
+
+  private static PrefixCodedTerms packTerms(String field, Collection<BytesRef> terms) {
     BytesRef[] sortedTerms = terms.toArray(new BytesRef[0]);
     // already sorted if we are a SortedSet with natural order
     boolean sorted =
         terms instanceof SortedSet && ((SortedSet<BytesRef>) terms).comparator() == null;
-    if (!sorted) {
+    if (sorted == false) {
       ArrayUtil.timSort(sortedTerms);
     }
     PrefixCodedTerms.Builder builder = new PrefixCodedTerms.Builder();
@@ -101,29 +127,13 @@ public class TermInSetQuery extends Query implements Accountable {
       builder.add(field, term);
       previous.copyBytes(term);
     }
-    this.field = field;
-    termData = builder.finish();
-    termDataHashCode = termData.hashCode();
-  }
 
-  /** Creates a new {@link TermInSetQuery} from the given array of terms. */
-  public TermInSetQuery(String field, BytesRef... terms) {
-    this(field, Arrays.asList(terms));
+    return builder.finish();
   }
 
   @Override
-  public Query rewrite(IndexSearcher indexSearcher) throws IOException {
-    final int threshold =
-        Math.min(BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD, IndexSearcher.getMaxClauseCount());
-    if (termData.size() <= threshold) {
-      BooleanQuery.Builder bq = new BooleanQuery.Builder();
-      TermIterator iterator = termData.iterator();
-      for (BytesRef term = iterator.next(); term != null; term = iterator.next()) {
-        bq.add(new TermQuery(new Term(iterator.field(), BytesRef.deepCopyOf(term))), Occur.SHOULD);
-      }
-      return new ConstantScoreQuery(bq.build());
-    }
-    return super.rewrite(indexSearcher);
+  public long getTermsCount() throws IOException {
+    return termData.size();
   }
 
   @Override
@@ -203,233 +213,52 @@ public class TermInSetQuery extends Query implements Accountable {
     return Collections.emptyList();
   }
 
-  private static class TermAndState {
-    final String field;
-    final TermsEnum termsEnum;
-    final BytesRef term;
-    final TermState state;
-    final int docFreq;
-    final long totalTermFreq;
-
-    TermAndState(String field, TermsEnum termsEnum) throws IOException {
-      this.field = field;
-      this.termsEnum = termsEnum;
-      this.term = BytesRef.deepCopyOf(termsEnum.term());
-      this.state = termsEnum.termState();
-      this.docFreq = termsEnum.docFreq();
-      this.totalTermFreq = termsEnum.totalTermFreq();
-    }
+  @Override
+  protected TermsEnum getTermsEnum(Terms terms, AttributeSource atts) throws IOException {
+    return new SetEnum(terms.iterator());
   }
 
-  private static class WeightOrDocIdSet {
-    final Weight weight;
-    final DocIdSet set;
-
-    WeightOrDocIdSet(Weight weight) {
-      this.weight = Objects.requireNonNull(weight);
-      this.set = null;
-    }
-
-    WeightOrDocIdSet(DocIdSet bitset) {
-      this.set = bitset;
-      this.weight = null;
+  /**
+   * Like a baby {@link org.apache.lucene.index.AutomatonTermsEnum}, ping-pong intersects the terms
+   * dict against our encoded query terms.
+   */
+  private class SetEnum extends FilteredTermsEnum {
+    private final TermIterator iterator;
+    private BytesRef seekTerm;
+
+    SetEnum(TermsEnum termsEnum) {
+      super(termsEnum);
+      iterator = termData.iterator();
+      seekTerm = iterator.next();
     }
-  }
-
-  @Override
-  public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
-      throws IOException {
-    return new ConstantScoreWeight(this, boost) {
 
-      @Override
-      public Matches matches(LeafReaderContext context, int doc) throws IOException {
-        Terms terms = Terms.getTerms(context.reader(), field);
-        if (terms.hasPositions() == false) {
-          return super.matches(context, doc);
-        }
-        return MatchesUtils.forField(
-            field,
-            () ->
-                DisjunctionMatchesIterator.fromTermsEnum(
-                    context, doc, getQuery(), field, termData.iterator()));
+    @Override
+    protected AcceptStatus accept(BytesRef term) throws IOException {
+      // next() our iterator until it is >= the incoming term
+      // if it matches exactly, it's a hit, otherwise it's a miss
+      int cmp = 0;
+      while (seekTerm != null && (cmp = seekTerm.compareTo(term)) < 0) {
+        seekTerm = iterator.next();
       }
-
-      /**
-       * On the given leaf context, try to either rewrite to a disjunction if there are few matching
-       * terms, or build a bitset containing matching docs.
-       */
-      private WeightOrDocIdSet rewrite(LeafReaderContext context) throws IOException {
-        final LeafReader reader = context.reader();
-
-        Terms terms = reader.terms(field);
-        if (terms == null) {
-          return null;
-        }
-        final int fieldDocCount = terms.getDocCount();
-        TermsEnum termsEnum = terms.iterator();
-        PostingsEnum docs = null;
-        TermIterator iterator = termData.iterator();
-
-        // We will first try to collect up to 'threshold' terms into 'matchingTerms'
-        // if there are too many terms, we will fall back to building the 'builder'
-        final int threshold =
-            Math.min(BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD, IndexSearcher.getMaxClauseCount());
-        assert termData.size() > threshold : "Query should have been rewritten";
-        List<TermAndState> matchingTerms = new ArrayList<>(threshold);
-        DocIdSetBuilder builder = null;
-
-        for (BytesRef term = iterator.next(); term != null; term = iterator.next()) {
-          assert field.equals(iterator.field());
-          if (termsEnum.seekExact(term)) {
-            // If a term contains all docs with a value for the specified field (likely rare),
-            // we can discard the other terms and just use the dense term's postings:
-            int docFreq = termsEnum.docFreq();
-            if (fieldDocCount == docFreq) {
-              TermStates termStates = new TermStates(searcher.getTopReaderContext());
-              termStates.register(
-                  termsEnum.termState(), context.ord, docFreq, termsEnum.totalTermFreq());
-              Query q =
-                  new ConstantScoreQuery(
-                      new TermQuery(new Term(field, termsEnum.term()), termStates));
-              Weight weight = searcher.rewrite(q).createWeight(searcher, scoreMode, score());
-              return new WeightOrDocIdSet(weight);
-            }
-
-            if (matchingTerms == null) {
-              docs = termsEnum.postings(docs, PostingsEnum.NONE);
-              builder.add(docs);
-            } else if (matchingTerms.size() < threshold) {
-              matchingTerms.add(new TermAndState(field, termsEnum));
-            } else {
-              assert matchingTerms.size() == threshold;
-              builder = new DocIdSetBuilder(reader.maxDoc(), terms);
-              docs = termsEnum.postings(docs, PostingsEnum.NONE);
-              builder.add(docs);
-              for (TermAndState t : matchingTerms) {
-                t.termsEnum.seekExact(t.term, t.state);
-                docs = t.termsEnum.postings(docs, PostingsEnum.NONE);
-                builder.add(docs);
-              }
-              matchingTerms = null;
-            }
-          }
-        }
-
-        if (matchingTerms != null) {
-          assert builder == null;
-          BooleanQuery.Builder bq = new BooleanQuery.Builder();
-          for (TermAndState t : matchingTerms) {
-            final TermStates termStates = new TermStates(searcher.getTopReaderContext());
-            termStates.register(t.state, context.ord, t.docFreq, t.totalTermFreq);
-            bq.add(new TermQuery(new Term(t.field, t.term), termStates), Occur.SHOULD);
-          }
-          Query q = new ConstantScoreQuery(bq.build());
-          final Weight weight = searcher.rewrite(q).createWeight(searcher, scoreMode, score());
-          return new WeightOrDocIdSet(weight);
-        } else {
-          assert builder != null;
-          return new WeightOrDocIdSet(builder.build());
-        }
-      }
-
-      private Scorer scorer(DocIdSet set) throws IOException {
-        if (set == null) {
-          return null;
-        }
-        final DocIdSetIterator disi = set.iterator();
-        if (disi == null) {
-          return null;
-        }
-        return new ConstantScoreScorer(this, score(), scoreMode, disi);
-      }
-
-      @Override
-      public BulkScorer bulkScorer(LeafReaderContext context) throws IOException {
-        final WeightOrDocIdSet weightOrBitSet = rewrite(context);
-        if (weightOrBitSet == null) {
-          return null;
-        } else if (weightOrBitSet.weight != null) {
-          return weightOrBitSet.weight.bulkScorer(context);
-        } else {
-          final Scorer scorer = scorer(weightOrBitSet.set);
-          if (scorer == null) {
-            return null;
-          }
-          return new DefaultBulkScorer(scorer);
-        }
-      }
-
-      @Override
-      public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
-        Terms indexTerms = context.reader().terms(field);
-        if (indexTerms == null) {
-          return null;
-        }
-
-        // Cost estimation reasoning is:
-        //  1. Assume every query term matches at least one document (queryTermsCount).
-        //  2. Determine the total number of docs beyond the first one for each term.
-        //     That count provides a ceiling on the number of extra docs that could match beyond
-        //     that first one. (We omit the first since it's already been counted in #1).
-        // This approach still provides correct worst-case cost in general, but provides tighter
-        // estimates for primary-key-like fields. See: LUCENE-10207
-
-        // TODO: This cost estimation may grossly overestimate since we have no index statistics
-        // for the specific query terms. While it's nice to avoid the cost of intersecting the
-        // query terms with the index, it could be beneficial to do that work and get better
-        // cost estimates.
-        final long cost;
-        final long queryTermsCount = termData.size();
-        long potentialExtraCost = indexTerms.getSumDocFreq();
-        final long indexedTermCount = indexTerms.size();
-        if (indexedTermCount != -1) {
-          potentialExtraCost -= indexedTermCount;
-        }
-        cost = queryTermsCount + potentialExtraCost;
-
-        final Weight weight = this;
-        return new ScorerSupplier() {
-          @Override
-          public Scorer get(long leadCost) throws IOException {
-            WeightOrDocIdSet weightOrDocIdSet = rewrite(context);
-            final Scorer scorer;
-            if (weightOrDocIdSet == null) {
-              scorer = null;
-            } else if (weightOrDocIdSet.weight != null) {
-              scorer = weightOrDocIdSet.weight.scorer(context);
-            } else {
-              scorer = scorer(weightOrDocIdSet.set);
-            }
-
-            return Objects.requireNonNullElseGet(
-                scorer,
-                () ->
-                    new ConstantScoreScorer(weight, score(), scoreMode, DocIdSetIterator.empty()));
-          }
-
-          @Override
-          public long cost() {
-            return cost;
-          }
-        };
+      if (seekTerm == null) {
+        return AcceptStatus.END;
+      } else if (cmp == 0) {
+        return AcceptStatus.YES_AND_SEEK;
+      } else {
+        return AcceptStatus.NO_AND_SEEK;
       }
+    }
 
-      @Override
-      public Scorer scorer(LeafReaderContext context) throws IOException {
-        final ScorerSupplier supplier = scorerSupplier(context);
-        if (supplier == null) {
-          return null;
-        }
-        return supplier.get(Long.MAX_VALUE);
+    @Override
+    protected BytesRef nextSeekTerm(BytesRef currentTerm) throws IOException {
+      // next() our iterator until it is > the currentTerm; we must always make progress.
+      if (currentTerm == null) {
+        return seekTerm;
       }
-
-      @Override
-      public boolean isCacheable(LeafReaderContext ctx) {
-        // Only cache instances that have a reasonable size. Otherwise it might cause memory issues
-        // with the query cache if most memory ends up being spent on queries rather than doc id
-        // sets.
-        return ramBytesUsed() <= RamUsageEstimator.QUERY_DEFAULT_RAM_BYTES_USED;
+      while (seekTerm != null && seekTerm.compareTo(currentTerm) <= 0) {
+        seekTerm = iterator.next();
       }
-    };
+      return seekTerm;
+    }
   }
 }
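To illustrate the ping-pong intersection above with hypothetical data: suppose the query terms are {b, c, f} and a segment's terms dictionary holds {a, b, d, e, f}. FilteredTermsEnum first seeks to nextSeekTerm(null) = b; the dictionary lands exactly on b and accept(b) returns YES_AND_SEEK. The next seek target is c, but the dictionary's seekCeil lands on d; accept(d) advances the query iterator past c to f and returns NO_AND_SEEK since f > d. Seeking to f lands on f and accept(f) returns YES_AND_SEEK, after which nextSeekTerm(f) exhausts the query iterator and returns null, ending the enum. Only b and f are produced, and both the terms dictionary and the encoded query terms are walked a single time.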
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTermInSetQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestTermInSetQuery.java
index bac984b4cdc..a62d7f8fc4d 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestTermInSetQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestTermInSetQuery.java
@@ -59,10 +59,12 @@ public class TestTermInSetQuery extends LuceneTestCase {
     BytesRef denseTerm = new BytesRef(TestUtil.randomAnalysisString(random(), 10, true));
 
     Set<BytesRef> randomTerms = new HashSet<>();
-    while (randomTerms.size() < TermInSetQuery.BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD) {
+    while (randomTerms.size()
+        < AbstractMultiTermQueryConstantScoreWrapper.BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD) {
       randomTerms.add(new BytesRef(TestUtil.randomAnalysisString(random(), 10, true)));
     }
-    assert randomTerms.size() == TermInSetQuery.BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD;
+    assert randomTerms.size()
+        == AbstractMultiTermQueryConstantScoreWrapper.BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD;
     BytesRef[] otherTerms = new BytesRef[randomTerms.size()];
     int idx = 0;
     for (BytesRef term : randomTerms) {
@@ -325,7 +327,10 @@ public class TestTermInSetQuery extends LuceneTestCase {
     final List<BytesRef> terms = new ArrayList<>();
     // enough terms to avoid the rewrite
     final int numTerms =
-        TestUtil.nextInt(random(), TermInSetQuery.BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD + 1, 100);
+        TestUtil.nextInt(
+            random(),
+            AbstractMultiTermQueryConstantScoreWrapper.BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD + 1,
+            100);
     for (int i = 0; i < numTerms; ++i) {
       final BytesRef term = newBytesRef(RandomStrings.randomUnicodeOfCodepointLength(random(), 10));
       terms.add(term);
diff --git a/lucene/join/src/java/org/apache/lucene/search/join/TermsQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/TermsQuery.java
index 9f4db1cb5b0..729feb19b9c 100644
--- a/lucene/join/src/java/org/apache/lucene/search/join/TermsQuery.java
+++ b/lucene/join/src/java/org/apache/lucene/search/join/TermsQuery.java
@@ -94,6 +94,11 @@ class TermsQuery extends MultiTermQuery implements Accountable {
     return new SeekingTermSetTermsEnum(terms.iterator(), this.terms, ords);
   }
 
+  @Override
+  public long getTermsCount() throws IOException {
+    return terms.size();
+  }
+
   @Override
   public String toString(String string) {
     return "TermsQuery{" + "field=" + field + "fromQuery=" + fromQuery.toString(field) + '}';
diff --git a/lucene/monitor/src/test/org/apache/lucene/monitor/TestPresearcherMatchCollector.java b/lucene/monitor/src/test/org/apache/lucene/monitor/TestPresearcherMatchCollector.java
index 47d24d85220..608df67de29 100644
--- a/lucene/monitor/src/test/org/apache/lucene/monitor/TestPresearcherMatchCollector.java
+++ b/lucene/monitor/src/test/org/apache/lucene/monitor/TestPresearcherMatchCollector.java
@@ -46,8 +46,8 @@ public class TestPresearcherMatchCollector extends MonitorTestBase {
 
       assertNotNull(matches.match("2", 0));
       String pm = matches.match("2", 0).presearcherMatches;
-      assertThat(pm, containsString("field:foo"));
-      assertThat(pm, containsString("f2:quuz"));
+      assertThat(pm, containsString("field:(foo test)"));
+      assertThat(pm, containsString("f2:(quuz)"));
 
       assertNotNull(matches.match("3", 0));
       assertEquals(" field:foo", matches.match("3", 0).presearcherMatches);