You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by ju...@apache.org on 2022/07/29 19:49:50 UTC

[lucene] branch branch_9x updated (6366cf2e7ad -> 33d5ab96f26)

This is an automated email from the ASF dual-hosted git repository.

julietibs pushed a change to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git


    from 6366cf2e7ad LUCENE-10633: Fix handling of missing values in reverse sorts.
     new 1559de836ce LUCENE-10663: Fix KnnVectorQuery explain with multiple segments (#1050)
     new 2cb0e260755 LUCENE-10504: KnnGraphTester to use KnnVectorQuery (#796)
     new 33d5ab96f26 LUCENE-10559: Add Prefilter Option to KnnGraphTester (#932)

The 3 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 lucene/CHANGES.txt                                 |   4 +-
 .../org/apache/lucene/search/KnnVectorQuery.java   |   2 +-
 .../apache/lucene/search/TestKnnVectorQuery.java   |  28 ++++
 .../apache/lucene/util/hnsw/KnnGraphTester.java    | 182 +++++++++++++++++----
 4 files changed, 177 insertions(+), 39 deletions(-)


[lucene] 02/03: LUCENE-10504: KnnGraphTester to use KnnVectorQuery (#796)

Posted by ju...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

julietibs pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git

commit 2cb0e26075559e4ce38d2fa9765bcccaa187ce0d
Author: Michael Sokolov <so...@falutin.net>
AuthorDate: Wed May 4 18:22:48 2022 -0400

    LUCENE-10504: KnnGraphTester to use KnnVectorQuery (#796)
    
    * LUCENE-10504: KnnGraphTester to use KnnVectorQuery
---
 .../apache/lucene/util/hnsw/KnnGraphTester.java    | 52 +++++++++++-----------
 1 file changed, 26 insertions(+), 26 deletions(-)

diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
index 0ae7dea5105..1e9fdbbf126 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
@@ -32,8 +32,10 @@ import java.nio.channels.FileChannel;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.Paths;
+import java.nio.file.attribute.FileTime;
 import java.util.HashSet;
 import java.util.Locale;
+import java.util.Objects;
 import java.util.Set;
 import org.apache.lucene.codecs.KnnVectorsFormat;
 import org.apache.lucene.codecs.KnnVectorsReader;
@@ -47,7 +49,6 @@ import org.apache.lucene.document.KnnVectorField;
 import org.apache.lucene.document.StoredField;
 import org.apache.lucene.index.CodecReader;
 import org.apache.lucene.index.DirectoryReader;
-import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.IndexWriter;
 import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.index.LeafReader;
@@ -55,11 +56,12 @@ import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.RandomAccessVectorValues;
 import org.apache.lucene.index.RandomAccessVectorValuesProducer;
 import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.KnnVectorQuery;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.store.FSDirectory;
-import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.IntroSorter;
 import org.apache.lucene.util.PrintStreamInfoStream;
@@ -79,7 +81,6 @@ public class KnnGraphTester {
   private int numDocs;
   private int dim;
   private int topK;
-  private int warmCount;
   private int numIters;
   private int fanout;
   private Path indexPath;
@@ -98,7 +99,6 @@ public class KnnGraphTester {
     numIters = 1000;
     dim = 256;
     topK = 100;
-    warmCount = 1000;
     fanout = topK;
     similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
   }
@@ -178,9 +178,6 @@ public class KnnGraphTester {
         case "-out":
           outputPath = Paths.get(args[++iarg]);
           break;
-        case "-warm":
-          warmCount = Integer.parseInt(args[++iarg]);
-          break;
         case "-docs":
           docVectorsPath = Paths.get(args[++iarg]);
           break;
@@ -349,8 +346,9 @@ public class KnnGraphTester {
     TopDocs[] results = new TopDocs[numIters];
     long elapsed, totalCpuTime, totalVisited = 0;
     try (FileChannel q = FileChannel.open(queryPath)) {
+      int bufferSize = numIters * dim * Float.BYTES;
       FloatBuffer targets =
-          q.map(FileChannel.MapMode.READ_ONLY, 0, numIters * dim * Float.BYTES)
+          q.map(FileChannel.MapMode.READ_ONLY, 0, bufferSize)
               .order(ByteOrder.LITTLE_ENDIAN)
               .asFloatBuffer();
       float[] target = new float[dim];
@@ -362,18 +360,19 @@ public class KnnGraphTester {
       long cpuTimeStartNs;
       try (Directory dir = FSDirectory.open(indexPath);
           DirectoryReader reader = DirectoryReader.open(dir)) {
+        IndexSearcher searcher = new IndexSearcher(reader);
         numDocs = reader.maxDoc();
-        for (int i = 0; i < warmCount; i++) {
+        for (int i = 0; i < numIters; i++) {
           // warm up
           targets.get(target);
-          results[i] = doKnnSearch(reader, KNN_FIELD, target, topK, fanout);
+          doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout);
         }
         targets.position(0);
         start = System.nanoTime();
         cpuTimeStartNs = bean.getCurrentThreadCpuTime();
         for (int i = 0; i < numIters; i++) {
           targets.get(target);
-          results[i] = doKnnSearch(reader, KNN_FIELD, target, topK, fanout);
+          results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout);
         }
         totalCpuTime = (bean.getCurrentThreadCpuTime() - cpuTimeStartNs) / 1_000_000;
         elapsed = (System.nanoTime() - start) / 1_000_000; // ns -> ms
@@ -430,19 +429,9 @@ public class KnnGraphTester {
     }
   }
 
-  private static TopDocs doKnnSearch(
-      IndexReader reader, String field, float[] vector, int k, int fanout) throws IOException {
-    TopDocs[] results = new TopDocs[reader.leaves().size()];
-    for (LeafReaderContext ctx : reader.leaves()) {
-      Bits liveDocs = ctx.reader().getLiveDocs();
-      results[ctx.ord] =
-          ctx.reader().searchNearestVectors(field, vector, k + fanout, liveDocs, Integer.MAX_VALUE);
-      int docBase = ctx.docBase;
-      for (ScoreDoc scoreDoc : results[ctx.ord].scoreDocs) {
-        scoreDoc.doc += docBase;
-      }
-    }
-    return TopDocs.merge(k, results);
+  private static TopDocs doKnnVectorQuery(
+      IndexSearcher searcher, String field, float[] vector, int k, int fanout) throws IOException {
+    return searcher.search(new KnnVectorQuery(field, vector, k + fanout), k);
   }
 
   private float checkResults(TopDocs[] results, int[][] nn) {
@@ -487,9 +476,10 @@ public class KnnGraphTester {
 
   private int[][] getNN(Path docPath, Path queryPath) throws IOException {
     // look in working directory for cached nn file
-    String nnFileName = "nn-" + numDocs + "-" + numIters + "-" + topK + "-" + dim + ".bin";
+    String hash = Integer.toString(Objects.hash(docPath, queryPath, numDocs, numIters, topK), 36);
+    String nnFileName = "nn-" + hash + ".bin";
     Path nnPath = Paths.get(nnFileName);
-    if (Files.exists(nnPath)) {
+    if (Files.exists(nnPath) && isNewer(nnPath, docPath, queryPath)) {
       return readNN(nnPath);
     } else {
       int[][] nn = computeNN(docPath, queryPath);
@@ -498,6 +488,16 @@ public class KnnGraphTester {
     }
   }
 
+  private boolean isNewer(Path path, Path... others) throws IOException {
+    FileTime modified = Files.getLastModifiedTime(path);
+    for (Path other : others) {
+      if (Files.getLastModifiedTime(other).compareTo(modified) >= 0) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   private int[][] readNN(Path nnPath) throws IOException {
     int[][] result = new int[numIters][];
     try (FileChannel in = FileChannel.open(nnPath)) {


[lucene] 03/03: LUCENE-10559: Add Prefilter Option to KnnGraphTester (#932)

Posted by ju...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

julietibs pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git

commit 33d5ab96f266447b228833baee3489b12bdc3a68
Author: Kaival Parikh <46...@users.noreply.github.com>
AuthorDate: Fri Jul 29 23:51:34 2022 +0530

    LUCENE-10559: Add Prefilter Option to KnnGraphTester (#932)
    
    Added `-prefilter` and `-filterSelectivity` arguments to KnnGraphTester so
    that pre-filtering and post-filtering benchmarks can be compared.
    
    `filterSelectivity` is the probability that each document passes the filter;
    the passing documents are chosen at random. They are stored in a FixedBitSet,
    which is used both to compute the true nearest neighbors and to filter the
    HNSW search.
    
    For post-filtering, we over-select `topK / filterSelectivity` results so that
    the number of hits left after filtering is close to the requested `topK`. For
    pre-filtering, we wrap the FixedBitSet in a query and pass it to KnnVectorQuery
    as its filter argument.
---
 lucene/CHANGES.txt                                 |   2 +-
 .../apache/lucene/util/hnsw/KnnGraphTester.java    | 140 ++++++++++++++++++---
 2 files changed, 126 insertions(+), 16 deletions(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 09f5ef4832a..a5937d86dbd 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -34,7 +34,7 @@ Bug Fixes
 
 Other
 ---------------------
-(No changes)
+* LUCENE-10559: Add Prefilter Option to KnnGraphTester (Kaival Parikh)
 
 ======================== Lucene 9.3.0 =======================
 
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
index 1e9fdbbf126..140e8f9df69 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
@@ -33,6 +33,7 @@ import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.nio.file.attribute.FileTime;
+import java.util.Arrays;
 import java.util.HashSet;
 import java.util.Locale;
 import java.util.Objects;
@@ -56,13 +57,22 @@ import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.RandomAccessVectorValues;
 import org.apache.lucene.index.RandomAccessVectorValuesProducer;
 import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.ConstantScoreScorer;
+import org.apache.lucene.search.ConstantScoreWeight;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.KnnVectorQuery;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.QueryVisitor;
 import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Scorer;
 import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.Weight;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.store.FSDirectory;
+import org.apache.lucene.util.BitSetIterator;
 import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.FixedBitSet;
 import org.apache.lucene.util.IntroSorter;
 import org.apache.lucene.util.PrintStreamInfoStream;
 import org.apache.lucene.util.SuppressForbidden;
@@ -91,8 +101,10 @@ public class KnnGraphTester {
   private int beamWidth;
   private int maxConn;
   private VectorSimilarityFunction similarityFunction;
+  private FixedBitSet matchDocs;
+  private float selectivity;
+  private boolean prefilter;
 
-  @SuppressForbidden(reason = "uses Random()")
   private KnnGraphTester() {
     // set defaults
     numDocs = 1000;
@@ -101,6 +113,8 @@ public class KnnGraphTester {
     topK = 100;
     fanout = topK;
     similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
+    selectivity = 1f;
+    prefilter = false;
   }
 
   public static void main(String... args) throws Exception {
@@ -192,6 +206,18 @@ public class KnnGraphTester {
         case "-forceMerge":
           forceMerge = true;
           break;
+        case "-prefilter":
+          prefilter = true;
+          break;
+        case "-filterSelectivity":
+          if (iarg == args.length - 1) {
+            throw new IllegalArgumentException("-filterSelectivity requires a following float");
+          }
+          selectivity = Float.parseFloat(args[++iarg]);
+          if (selectivity <= 0 || selectivity >= 1) {
+            throw new IllegalArgumentException("-filterSelectivity must be between 0 and 1");
+          }
+          break;
         case "-quiet":
           quiet = true;
           break;
@@ -203,6 +229,9 @@ public class KnnGraphTester {
     if (operation == null && reindex == false) {
       usage();
     }
+    if (prefilter == true && selectivity == 1f) {
+      throw new IllegalArgumentException("-prefilter requires filterSelectivity between 0 and 1");
+    }
     indexPath = Paths.get(formatIndexPath(docVectorsPath));
     if (reindex) {
       if (docVectorsPath == null) {
@@ -219,6 +248,7 @@ public class KnnGraphTester {
           if (docVectorsPath == null) {
             throw new IllegalArgumentException("missing -docs arg");
           }
+          matchDocs = generateRandomBitSet(numDocs, selectivity);
           if (outputPath != null) {
             testSearch(indexPath, queryPath, outputPath, null);
           } else {
@@ -362,17 +392,33 @@ public class KnnGraphTester {
           DirectoryReader reader = DirectoryReader.open(dir)) {
         IndexSearcher searcher = new IndexSearcher(reader);
         numDocs = reader.maxDoc();
+        Query bitSetQuery = new BitSetQuery(matchDocs);
         for (int i = 0; i < numIters; i++) {
           // warm up
           targets.get(target);
-          doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout);
+          if (prefilter) {
+            doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
+          } else {
+            doKnnVectorQuery(searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
+          }
         }
         targets.position(0);
         start = System.nanoTime();
         cpuTimeStartNs = bean.getCurrentThreadCpuTime();
         for (int i = 0; i < numIters; i++) {
           targets.get(target);
-          results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout);
+          if (prefilter) {
+            results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
+          } else {
+            results[i] =
+                doKnnVectorQuery(
+                    searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
+
+            results[i].scoreDocs =
+                Arrays.stream(results[i].scoreDocs)
+                    .filter(scoreDoc -> matchDocs == null || matchDocs.get(scoreDoc.doc))
+                    .toArray(ScoreDoc[]::new);
+          }
         }
         totalCpuTime = (bean.getCurrentThreadCpuTime() - cpuTimeStartNs) / 1_000_000;
         elapsed = (System.nanoTime() - start) / 1_000_000; // ns -> ms
@@ -417,7 +463,7 @@ public class KnnGraphTester {
       totalVisited /= numIters;
       System.out.printf(
           Locale.ROOT,
-          "%5.3f\t%5.2f\t%d\t%d\t%d\t%d\t%d\t%d\n",
+          "%5.3f\t%5.2f\t%d\t%d\t%d\t%d\t%d\t%d\t%.2f\t%s\n",
           recall,
           totalCpuTime / (float) numIters,
           numDocs,
@@ -425,21 +471,22 @@ public class KnnGraphTester {
           maxConn,
           beamWidth,
           totalVisited,
-          reindexTimeMsec);
+          reindexTimeMsec,
+          selectivity,
+          prefilter ? "pre-filter" : "post-filter");
     }
   }
 
   private static TopDocs doKnnVectorQuery(
-      IndexSearcher searcher, String field, float[] vector, int k, int fanout) throws IOException {
-    return searcher.search(new KnnVectorQuery(field, vector, k + fanout), k);
+      IndexSearcher searcher, String field, float[] vector, int k, int fanout, Query filter)
+      throws IOException {
+    return searcher.search(new KnnVectorQuery(field, vector, k + fanout, filter), k);
   }
 
   private float checkResults(TopDocs[] results, int[][] nn) {
     int totalMatches = 0;
-    int totalResults = 0;
+    int totalResults = results.length * topK;
     for (int i = 0; i < results.length; i++) {
-      int n = results[i].scoreDocs.length;
-      totalResults += n;
       // System.out.println(Arrays.toString(nn[i]));
       // System.out.println(Arrays.toString(results[i].scoreDocs));
       totalMatches += compareNN(nn[i], results[i]);
@@ -463,7 +510,7 @@ public class KnnGraphTester {
     System.out.print('\n');
     */
     Set<Integer> expectedSet = new HashSet<>();
-    for (int i = 0; i < results.scoreDocs.length; i++) {
+    for (int i = 0; i < topK; i++) {
       expectedSet.add(expected[i]);
     }
     for (ScoreDoc scoreDoc : results.scoreDocs) {
@@ -479,11 +526,13 @@ public class KnnGraphTester {
     String hash = Integer.toString(Objects.hash(docPath, queryPath, numDocs, numIters, topK), 36);
     String nnFileName = "nn-" + hash + ".bin";
     Path nnPath = Paths.get(nnFileName);
-    if (Files.exists(nnPath) && isNewer(nnPath, docPath, queryPath)) {
+    if (Files.exists(nnPath) && isNewer(nnPath, docPath, queryPath) && selectivity == 1f) {
       return readNN(nnPath);
     } else {
       int[][] nn = computeNN(docPath, queryPath);
-      writeNN(nn, nnPath);
+      if (selectivity == 1f) {
+        writeNN(nn, nnPath);
+      }
       return nn;
     }
   }
@@ -527,6 +576,19 @@ public class KnnGraphTester {
     }
   }
 
+  @SuppressForbidden(reason = "Uses random()")
+  private static FixedBitSet generateRandomBitSet(int size, float selectivity) {
+    FixedBitSet bitSet = new FixedBitSet(size);
+    for (int i = 0; i < size; i++) {
+      if (Math.random() < selectivity) {
+        bitSet.set(i);
+      } else {
+        bitSet.clear(i);
+      }
+    }
+    return bitSet;
+  }
+
   private int[][] computeNN(Path docPath, Path queryPath) throws IOException {
     int[][] result = new int[numIters][];
     if (quiet == false) {
@@ -558,7 +620,9 @@ public class KnnGraphTester {
           for (; j < numDocs && vectors.hasRemaining(); j++) {
             vectors.get(vector);
             float d = similarityFunction.compare(query, vector);
-            queue.insertWithOverflow(j, d);
+            if (matchDocs == null || matchDocs.get(j)) {
+              queue.insertWithOverflow(j, d);
+            }
           }
           result[i] = new int[topK];
           for (int k = topK - 1; k >= 0; k--) {
@@ -633,7 +697,7 @@ public class KnnGraphTester {
 
   private static void usage() {
     String error =
-        "Usage: TestKnnGraph [-reindex] [-search {queryfile}|-stats|-check] [-docs {datafile}] [-niter N] [-fanout N] [-maxConn N] [-beamWidth N]";
+        "Usage: TestKnnGraph [-reindex] [-search {queryfile}|-stats|-check] [-docs {datafile}] [-niter N] [-fanout N] [-maxConn N] [-beamWidth N] [-filterSelectivity N] [-prefilter]";
     System.err.println(error);
     System.exit(1);
   }
@@ -729,4 +793,50 @@ public class KnnGraphTester {
       return Float.compare(score[pivot], score[j]);
     }
   }
+
+  private static class BitSetQuery extends Query {
+
+    private final FixedBitSet docs;
+    private final int cardinality;
+
+    BitSetQuery(FixedBitSet docs) {
+      this.docs = docs;
+      this.cardinality = docs.cardinality();
+    }
+
+    @Override
+    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
+        throws IOException {
+      return new ConstantScoreWeight(this, boost) {
+        @Override
+        public Scorer scorer(LeafReaderContext context) throws IOException {
+          return new ConstantScoreScorer(
+              this, score(), scoreMode, new BitSetIterator(docs, cardinality));
+        }
+
+        @Override
+        public boolean isCacheable(LeafReaderContext ctx) {
+          return false;
+        }
+      };
+    }
+
+    @Override
+    public void visit(QueryVisitor visitor) {}
+
+    @Override
+    public String toString(String field) {
+      return "BitSetQuery";
+    }
+
+    @Override
+    public boolean equals(Object other) {
+      return sameClassAs(other) && docs.equals(((BitSetQuery) other).docs);
+    }
+
+    @Override
+    public int hashCode() {
+      return 31 * classHash() + docs.hashCode();
+    }
+  }
 }


[lucene] 01/03: LUCENE-10663: Fix KnnVectorQuery explain with multiple segments (#1050)

Posted by ju...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

julietibs pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git

commit 1559de836ce347a8fd8e5ffbbb51fda14f8c16cf
Author: Shiming Li <li...@live.com>
AuthorDate: Fri Jul 29 01:31:49 2022 +0800

    LUCENE-10663: Fix KnnVectorQuery explain with multiple segments (#1050)
    
    When an index has multiple segments, KnnVectorQuery explain looks up the
    wrong doc ID. The doc ID passed to explain is relative to its segment (it
    does not include the segment's docBase). The doc IDs stored in
    KnnVectorQuery.DocAndScoreQuery, however, are index-wide, because each
    segment's docBase has already been added. So DocAndScoreQuery.explain must add
    the segment's docBase to the doc ID before searching for it.
    
    Co-authored-by: Julie Tibshirani <ju...@apache.org>
---
 lucene/CHANGES.txt                                 |  2 +-
 .../org/apache/lucene/search/KnnVectorQuery.java   |  2 +-
 .../apache/lucene/search/TestKnnVectorQuery.java   | 28 ++++++++++++++++++++++
 3 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index fbaa99d3301..09f5ef4832a 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -30,7 +30,7 @@ Optimizations
 
 Bug Fixes
 ---------------------
-(No changes)
+* LUCENE-10663: Fix KnnVectorQuery explain with multiple segments. (Shiming Li)
 
 Other
 ---------------------
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
index 6e68de193da..9fe39686ca5 100644
--- a/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
@@ -318,7 +318,7 @@ public class KnnVectorQuery extends Query {
       return new Weight(this) {
         @Override
         public Explanation explain(LeafReaderContext context, int doc) {
-          int found = Arrays.binarySearch(docs, doc);
+          int found = Arrays.binarySearch(docs, doc + context.docBase);
           if (found < 0) {
             return Explanation.noMatch("not in top " + k);
           }
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
index 74ecf23c292..4d826126ac0 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
@@ -446,6 +446,34 @@ public class TestKnnVectorQuery extends LuceneTestCase {
     }
   }
 
+  public void testExplainMultipleSegments() throws IOException {
+    try (Directory d = newDirectory()) {
+      try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
+        for (int j = 0; j < 5; j++) {
+          Document doc = new Document();
+          doc.add(new KnnVectorField("field", new float[] {j, j}));
+          w.addDocument(doc);
+          w.commit();
+        }
+      }
+      try (IndexReader reader = DirectoryReader.open(d)) {
+        IndexSearcher searcher = new IndexSearcher(reader);
+        KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
+        Explanation matched = searcher.explain(query, 2);
+        assertTrue(matched.isMatch());
+        assertEquals(1 / 2f, matched.getValue());
+        assertEquals(0, matched.getDetails().length);
+        assertEquals("within top 3", matched.getDescription());
+
+        Explanation nomatch = searcher.explain(query, 4);
+        assertFalse(nomatch.isMatch());
+        assertEquals(0f, nomatch.getValue());
+        assertEquals(0, matched.getDetails().length);
+        assertEquals("not in top 3", nomatch.getDescription());
+      }
+    }
+  }
+
   /** Test that when vectors are abnormally distributed among segments, we still find the top K */
   public void testSkewedIndex() throws IOException {
     /* We have to choose the numbers carefully here so that some segment has more than the expected