Posted to issues@lucene.apache.org by GitBox <gi...@apache.org> on 2021/08/06 20:37:23 UTC

[GitHub] [lucene] msokolov opened a new pull request #235: LUCENE-9614: add KnnVectorQuery implementation

msokolov opened a new pull request #235:
URL: https://github.com/apache/lucene/pull/235


   This is a first cut at a query based on k-nearest neighbors, with the vector search done as part of `rewrite()`. A couple of quirks:
   
   1. I noticed that scores for the default similarity (Euclidean) had very low precision as the distances got large. Because the ordering is reversed (smaller distances mean higher scores), we need to invert the distances to be compatible with Lucene search scores. The way we were handling this was to apply an `exp(-distance)` to convert distances to scores. That's theoretically sound, but in practice anything over 100 or so was underflowing to zero and becoming indistinguishable. As a stopgap measure, I changed the behavior so that the scores returned by vector search are allowed to be negative and get set to `-distance` for the reverse-score (Euclidean distance) case. It's in theory OK for these to be negative as long as they are not directly used as Lucene result scores. I added a further conversion in the Query implementation here that simply shifts the scores by the minimum score *for this query*, so that they are non-negative. This is perfectly valid for a single query, but the scores are not comparable across queries, and indeed not even across the same query run on multiple indexes, so it would present problems for distributed implementations. I'm not sure what to do about this yet, and am looking for suggestions.
   2. There's a clever implementation (hack?!) to deal with trying to minimize over-collection across multiple segments. Basically the idea is to optimistically collect the expected proportion of top K based on the segment size (plus a margin), and then to re-run the query if we can't prove we exhaustively searched the segment. I think it's sound, but welcome comments on that bit since it's a little exotic.
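
To make the heuristic in item 2 concrete, here is a minimal standalone sketch of the arithmetic, assuming the top k are spread across segments in proportion to segment size. `PerLeafK`, `kPerLeaf`, and `needsRequery` are hypothetical names used only for illustration; they are not part of Lucene. The sketch also assumes `totalDocs > 0`.

```java
/** Hypothetical standalone model of the per-segment k estimate; not Lucene API. */
final class PerLeafK {

  /** Expected share of the top k for one segment, plus a margin of three standard deviations. */
  static int kPerLeaf(int k, int leafDocs, int totalDocs) {
    int boundedK = Math.min(k, totalDocs);
    // Use long arithmetic so boundedK * leafDocs cannot overflow an int.
    int expected = (int) Math.max(1L, (long) boundedK * leafDocs / totalDocs);
    return (int) (expected + 3 * Math.sqrt(expected));
  }

  /**
   * If the worst hit collected from a segment still made the global top k, the segment may hold
   * further competitive hits that were not collected, so it has to be searched again.
   */
  static boolean needsRequery(float leafWorstScore, float globalMinScore) {
    return leafWorstScore >= globalMinScore;
  }

  public static void main(String[] args) {
    System.out.println(kPerLeaf(100, 250_000, 1_000_000));
    System.out.println(kPerLeaf(100, 10_000, 1_000_000));
  }
}
```

With k = 100 and a segment holding a quarter of the documents, the expected share is 25 and the margin adds 3 * sqrt(25) = 15, so 40 candidates are collected from that segment instead of 100.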
   
   Finally I don't know whether this ought to get pushed without some performance testing; I'll start working on that soon using luceneutil.
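
The scoring problem from item 1 can be reproduced outside Lucene. The following is a small sketch, not the PR's code: `ReverseScoreDemo`, `expScore`, and `offsetScores` are hypothetical names, and the distances are made-up values chosen to show the effect.

```java
/** Hypothetical demonstration of the two distance-to-score conversions discussed above. */
final class ReverseScoreDemo {

  /** exp(-distance): bounded and order-preserving in theory, but underflows in float. */
  static float expScore(float distance) {
    return (float) Math.exp(-distance);
  }

  /** Stopgap: negate the distances, then shift by the minimum so every score is >= 0. */
  static float[] offsetScores(float[] distances) {
    float[] scores = new float[distances.length];
    float min = Float.POSITIVE_INFINITY;
    for (int i = 0; i < distances.length; i++) {
      scores[i] = -distances[i];
      min = Math.min(min, scores[i]);
    }
    for (int i = 0; i < scores.length; i++) {
      scores[i] -= min;
    }
    return scores;
  }

  public static void main(String[] args) {
    // Both distances collapse to the same float score.
    System.out.println(expScore(150f) + " " + expScore(200f));
    // Two unrelated result sets end up with identical shifted scores.
    System.out.println(java.util.Arrays.toString(offsetScores(new float[] {150f, 200f})));
    System.out.println(java.util.Arrays.toString(offsetScores(new float[] {1f, 51f})));
  }
}
```

The last two calls show why the shifted scores are not comparable across queries: hits at distances 150 and 200 receive the same scores as hits at distances 1 and 51, because only the gap to the worst hit of each query survives the shift.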


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@lucene.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscribe@lucene.apache.org
For additional commands, e-mail: issues-help@lucene.apache.org


[GitHub] [lucene] jpountz commented on a change in pull request #235: LUCENE-9614: add KnnVectorQuery implementation

Posted by GitBox <gi...@apache.org>.
jpountz commented on a change in pull request #235:
URL: https://github.com/apache/lucene/pull/235#discussion_r688273458



##########
File path: lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
##########
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Objects;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+
+/** Uses {@link KnnVectorsReader#search} to perform nearest Neighbour search. */
+public class KnnVectorQuery extends Query {
+
+  private static final TopDocs NO_RESULTS =
+      new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+
+  private final String field;
+  private final float[] target;
+  private final int k;
+
+  /**
+   * Find the <code>k</code> nearest documents to the target vector according to the vectors in the
+   * given field. <code>target</code> vector.
+   *
+   * @param field a field that has been indexed as a {@link KnnVectorField}.
+   * @param target the target of the search
+   * @param k the number of documents to find
+   * @throws IllegalArgumentException if <code>k</code> is less than 1
+   */
+  public KnnVectorQuery(String field, float[] target, int k) {
+    this.field = field;
+    this.target = target;
+    this.k = k;
+    if (k < 1) {
+      throw new IllegalArgumentException("k must be at least 1, got: " + k);
+    }
+  }
+
+  @Override
+  public Query rewrite(IndexReader reader) throws IOException {
+    int boundedK = Math.min(k, reader.numDocs());
+    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
+    for (LeafReaderContext ctx : reader.leaves()) {
+      // Calculate kPerLeaf as an overestimate of the expected number of the closest k documents in
+      // this leaf
+      int expectedKPerLeaf = Math.max(1, boundedK * ctx.reader().numDocs() / reader.numDocs());
+      // Increase to include 3 std. deviations of a Binomial distribution.
+      int kPerLeaf = (int) (expectedKPerLeaf + 3 * Math.sqrt(expectedKPerLeaf));
+      perLeafResults[ctx.ord] = searchLeaf(ctx, kPerLeaf);
+    }
+    // Merge sort the results
+    TopDocs topK = TopDocs.merge(boundedK, perLeafResults);
+    // re-query any outlier segments (normally there should be none).
+    topK = checkForOutlierSegments(reader, topK, perLeafResults);
+    if (topK.scoreDocs.length == 0) {
+      return new MatchNoDocsQuery();
+    }
+    return createRewrittenQuery(reader, topK);
+  }
+
+  private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
+    TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf);
+    if (results == null) {
+      return NO_RESULTS;
+    }
+    if (ctx.docBase > 0) {
+      for (ScoreDoc scoreDoc : results.scoreDocs) {
+        scoreDoc.doc += ctx.docBase;
+      }
+    }
+    return results;
+  }
+
+  private TopDocs checkForOutlierSegments(IndexReader reader, TopDocs topK, TopDocs[] perLeaf)
+      throws IOException {
+    int k = topK.scoreDocs.length;
+    if (k == 0) {
+      return topK;
+    }
+    float minScore = topK.scoreDocs[topK.scoreDocs.length - 1].score;
+    boolean rescored = false;
+    for (int i = 0; i < perLeaf.length; i++) {
+      if (perLeaf[i].scoreDocs[perLeaf[i].scoreDocs.length - 1].score >= minScore) {
+        // This segment's worst score was competitive; search it again, gathering full K this time
+        perLeaf[i] = searchLeaf(reader.leaves().get(i), topK.scoreDocs.length);
+        rescored = true;
+      }
+    }
+    if (rescored) {
+      return TopDocs.merge(k, perLeaf);
+    } else {
+      return topK;
+    }
+  }
+
+  private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
+    int len = topK.scoreDocs.length;
+    float minScore = topK.scoreDocs[len - 1].score;
+    Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
+    int[] docs = new int[len];
+    float[] scores = new float[len];
+    for (int i = 0; i < len; i++) {
+      docs[i] = topK.scoreDocs[i].doc;
+      scores[i] = topK.scoreDocs[i].score - minScore; // flip negative scores
+    }
+    int[] segmentStarts = findSegmentStarts(reader, docs);
+    return new DocAndScoreQuery(docs, scores, segmentStarts);
+  }
+
+  private int[] findSegmentStarts(IndexReader reader, int[] docs) {
+    int[] starts = new int[reader.leaves().size() + 1];
+    starts[starts.length - 1] = docs.length;
+    if (starts.length == 2) {
+      return starts;
+    }
+    int resultIndex = 0;
+    for (int i = 1; i < starts.length - 1; i++) {
+      int upper = reader.leaves().get(i).docBase;
+      resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
+      if (resultIndex < 0) {
+        resultIndex = -1 - resultIndex;
+      }
+      starts[i] = resultIndex;
+    }
+    return starts;
+  }
+
+  @Override
+  public String toString(String field) {
+    return "<vector:" + this.field + "[" + target[0] + ",...][" + k + "]>";
+  }
+
+  @Override
+  public void visit(QueryVisitor visitor) {
+    if (visitor.acceptField(field)) {
+      visitor.visitLeaf(this);
+    }
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    return obj instanceof KnnVectorQuery
+        && ((KnnVectorQuery) obj).k == k
+        && ((KnnVectorQuery) obj).field.equals(field)
+        && Arrays.equals(((KnnVectorQuery) obj).target, target);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(field, k, Arrays.hashCode(target));
+  }
+
+  /** Caches the results of a KnnVector search: a list of docs and their scores */
+  class DocAndScoreQuery extends Query {
+
+    private final int[] docs;
+    private final float[] scores;
+    private final int[] segmentStarts;
+
+    /**
+     * Constructor
+     *
+     * @param docs the global docids of documents that match, in ascending order
+     * @param scores the scores of the matching documents
+     * @param segmentStarts the indexes in docs and scores corresponding to the first matching
+     *     document in each segment. If a segment has no matching documents, it should be assigned
+     *     the index of the next segment that does. There should be a final entry that is always
+     *     docs.length-1.
+     */
+    DocAndScoreQuery(int[] docs, float[] scores, int[] segmentStarts) {
+      this.docs = docs;
+      this.scores = scores;
+      this.segmentStarts = segmentStarts;
+    }
+
+    @Override
+    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
+        throws IOException {
+      return new Weight(this) {
+        @Override
+        public Explanation explain(LeafReaderContext context, int doc) {
+          int found = Arrays.binarySearch(docs, doc);
+          if (found < 0) {
+            return Explanation.noMatch("not in top " + k);
+          }
+          return Explanation.match(scores[found], "within top " + k);
+        }
+
+        @Override
+        public Scorer scorer(LeafReaderContext context) {
+
+          return new Scorer(this) {
+            final int lower = segmentStarts[context.ord];
+            final int upper = segmentStarts[context.ord + 1];
+            int upTo = -1;
+
+            @Override
+            public DocIdSetIterator iterator() {
+              return new DocIdSetIterator() {
+                @Override
+                public int docID() {
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int nextDoc() {
+                  if (upTo == -1) {
+                    upTo = lower;
+                  } else {
+                    ++upTo;
+                  }
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int advance(int target) throws IOException {
+                  return slowAdvance(target);
+                }
+
+                @Override
+                public long cost() {
+                  return upper - lower;
+                }
+              };
+            }
+
+            @Override
+            public float getMaxScore(int docid) {
+              docid += context.docBase;
+              float maxScore = 0;
+              for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docid; idx++) {
+                maxScore = Math.max(maxScore, scores[idx]);
+              }
+              return maxScore;
+            }
+
+            @Override
+            public float score() {
+              if (upTo >= lower && upTo < upper) {
+                return scores[upTo];
+              }
+              return 0;
+            }
+
+            /**
+             * move the implementation of docID() into a differently-named method so we can call it
+             * from DocIdSetIterator.docID() even though this class is anonymous
+             *
+             * @return the current docid
+             */
+            private int docIdNoShadow() {
+              if (upTo == -1) {
+                return -1;
+              }
+              if (upTo >= upper) {
+                return NO_MORE_DOCS;
+              }
+              return docs[upTo] - context.docBase;
+            }
+
+            @Override
+            public int docID() {
+              return docIdNoShadow();
+            }
+          };
+        }
+
+        @Override
+        public boolean isCacheable(LeafReaderContext ctx) {

Review comment:
       This is the mental model I have. `Weight#isCacheable` tells whether it's legal to cache the weight, and then it's the job of the cache to tell whether caching is a good idea or not. I believe that one exception to this rule that we make is `TermInSetQuery` where we disable caching on clauses that use lots of memory.






[GitHub] [lucene] msokolov commented on a change in pull request #235: LUCENE-9614: add KnnVectorQuery implementation

Posted by GitBox <gi...@apache.org>.
msokolov commented on a change in pull request #235:
URL: https://github.com/apache/lucene/pull/235#discussion_r686244488



##########
File path: lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
##########
@@ -0,0 +1,318 @@
+        @Override
+        public Scorer scorer(LeafReaderContext context) {
+
+          return new Scorer(this) {
+            final int lower = segmentStarts[context.ord];
+            final int upper = segmentStarts[context.ord + 1];
+            int upTo = -1;
+
+            @Override
+            public DocIdSetIterator iterator() {
+              return new DocIdSetIterator() {
+                @Override
+                public int docID() {
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int nextDoc() {
+                  if (upTo == -1) {
+                    upTo = lower;
+                  } else {
+                    ++upTo;
+                  }
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int advance(int target) throws IOException {
+                  return slowAdvance(target);
+                }
+
+                @Override
+                public long cost() {
+                  return upper - lower;
+                }
+              };
+            }
+
+            @Override
+            public float getMaxScore(int docid) {
+              docid += context.docBase;
+              float maxScore = 0;
+              for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docid; idx++) {
+                maxScore = Math.max(maxScore, scores[idx]);
+              }
+              return maxScore;
+            }

Review comment:
       OK, I am not entirely clear what the contract of this method is, but I'll take a stab. I think it is *not* supposed to update the state of the iterator, but rather return the highest score among the docs after the current one that are <= the argument (like advance)?
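
One reading of that contract, sketched on plain arrays so it runs without Lucene. This is an assumption about the intended semantics, not a statement of what `Scorer#getMaxScore` requires; `MaxScoreModel` and `maxScoreUpTo` are hypothetical names, and doc ids here are segment-local for simplicity.

```java
/** Hypothetical array-based model of a non-mutating "max score up to a doc id" lookup. */
final class MaxScoreModel {

  /**
   * Returns the largest score among entries in [from, upper) whose doc id is <= upToDoc.
   * Nothing is mutated, so the caller's iteration position stays where it was.
   */
  static float maxScoreUpTo(int[] docs, float[] scores, int from, int upper, int upToDoc) {
    float max = 0f;
    for (int i = Math.max(0, from); i < upper && docs[i] <= upToDoc; i++) {
      max = Math.max(max, scores[i]);
    }
    return max;
  }

  public static void main(String[] args) {
    int[] docs = {2, 5, 9, 14};
    float[] scores = {0.3f, 0.9f, 0.1f, 0.7f};
    System.out.println(maxScoreUpTo(docs, scores, 0, docs.length, 5));
    System.out.println(maxScoreUpTo(docs, scores, 2, docs.length, 14));
  }
}
```

Because `docs` is sorted, the scan can stop at the first doc id beyond `upToDoc`, which keeps the lookup proportional to the number of hits in the requested range.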






[GitHub] [lucene] jtibshirani commented on a change in pull request #235: LUCENE-9614: add KnnVectorQuery implementation

Posted by GitBox <gi...@apache.org>.
jtibshirani commented on a change in pull request #235:
URL: https://github.com/apache/lucene/pull/235#discussion_r686637119



##########
File path: lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java
##########
@@ -43,9 +43,9 @@ public float compare(float[] v1, float[] v2) {
   };
 
   /**
-   * If true, the scores associated with vector comparisons are in reverse order; that is, lower
-   * scores represent more similar vectors. Otherwise, if false, higher scores represent more
-   * similar vectors.
+   * If true, the scores associated with vector comparisons are nonnegative and in reverse order;
+   * that is, lower scores represent more similar vectors. Otherwise, if false, higher scores
+   * represent more similar vectors, and scores may be negative or positive.

Review comment:
       I think we can revert part of this since you went with `1 / (1 + x)`.
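
For reference, a minimal sketch of the `1 / (1 + x)` mapping mentioned in this comment, assuming `x` is a non-negative distance. `ReciprocalScore` and `score` are hypothetical names for illustration, not the method in the PR.

```java
/** Hypothetical illustration of mapping a non-negative distance to a score in (0, 1]. */
final class ReciprocalScore {

  /** Smaller distances give larger scores; a distance of 0 maps to exactly 1. */
  static float score(float distance) {
    return 1f / (1f + distance);
  }

  public static void main(String[] args) {
    System.out.println(score(0f));
    System.out.println(score(150f) > score(200f));
  }
}
```

Unlike `exp(-distance)`, this decays slowly enough that distances in the hundreds still map to distinct positive floats, and the result does not depend on the other hits of the query.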

##########
File path: lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
##########
@@ -0,0 +1,318 @@
+package org.apache.lucene.search;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Objects;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+
+/** Uses {@link KnnVectorsReader#search} to perform nearest Neighbour search. */
+public class KnnVectorQuery extends Query {
+
+  private static final TopDocs NO_RESULTS =
+      new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+
+  private final String field;
+  private final float[] target;
+  private final int k;
+
+  /**
+   * Find the <code>k</code> nearest documents to the target vector according to the vectors in the
+   * given field. <code>target</code> vector.
+   *
+   * @param field a field that has been indexed as a {@link KnnVectorField}.
+   * @param target the target of the search
+   * @param k the number of documents to find
+   * @throws IllegalArgumentException if <code>k</code> is less than 1
+   */
+  public KnnVectorQuery(String field, float[] target, int k) {
+    this.field = field;
+    this.target = target;
+    this.k = k;
+    if (k < 1) {
+      throw new IllegalArgumentException("k must be at least 1, got: " + k);
+    }
+  }
+
+  @Override
+  public Query rewrite(IndexReader reader) throws IOException {
+    int boundedK = Math.min(k, reader.numDocs());
+    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
+    for (LeafReaderContext ctx : reader.leaves()) {
+      // Calculate kPerLeaf as an overestimate of the expected number of the closest k documents in
+      // this leaf
+      int expectedKPerLeaf = Math.max(1, boundedK * ctx.reader().numDocs() / reader.numDocs());
+      // Increase to include 3 std. deviations of a Binomial distribution.
+      int kPerLeaf = (int) (expectedKPerLeaf + 3 * Math.sqrt(expectedKPerLeaf));
+      perLeafResults[ctx.ord] = searchLeaf(ctx, kPerLeaf);
+    }
+    // Merge sort the results
+    TopDocs topK = TopDocs.merge(boundedK, perLeafResults);
+    // re-query any outlier segments (normally there should be none).
+    topK = checkForOutlierSegments(reader, topK, perLeafResults);
+    if (topK.scoreDocs.length == 0) {
+      return new MatchNoDocsQuery();
+    }
+    return createRewrittenQuery(reader, topK);
+  }
+
+  private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
+    TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf);
+    if (results == null) {
+      return NO_RESULTS;
+    }
+    if (ctx.docBase > 0) {
+      for (ScoreDoc scoreDoc : results.scoreDocs) {
+        scoreDoc.doc += ctx.docBase;
+      }
+    }
+    return results;
+  }
+
+  private TopDocs checkForOutlierSegments(IndexReader reader, TopDocs topK, TopDocs[] perLeaf)
+      throws IOException {
+    int k = topK.scoreDocs.length;
+    if (k == 0) {
+      return topK;
+    }
+    float minScore = topK.scoreDocs[topK.scoreDocs.length - 1].score;
+    boolean rescored = false;
+    for (int i = 0; i < perLeaf.length; i++) {
+      if (perLeaf[i].scoreDocs[perLeaf[i].scoreDocs.length - 1].score >= minScore) {
+        // This segment's worst score was competitive; search it again, gathering full K this time
+        perLeaf[i] = searchLeaf(reader.leaves().get(i), topK.scoreDocs.length);
+        rescored = true;
+      }
+    }
+    if (rescored) {
+      return TopDocs.merge(k, perLeaf);
+    } else {
+      return topK;
+    }
+  }
+
+  private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
+    int len = topK.scoreDocs.length;
+    float minScore = topK.scoreDocs[len - 1].score;
+    Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
+    int[] docs = new int[len];
+    float[] scores = new float[len];
+    for (int i = 0; i < len; i++) {
+      docs[i] = topK.scoreDocs[i].doc;
+      scores[i] = topK.scoreDocs[i].score - minScore; // flip negative scores
+    }
+    int[] segmentStarts = findSegmentStarts(reader, docs);
+    return new DocAndScoreQuery(docs, scores, segmentStarts);
+  }
+
+  private int[] findSegmentStarts(IndexReader reader, int[] docs) {
+    int[] starts = new int[reader.leaves().size() + 1];
+    starts[starts.length - 1] = docs.length;
+    if (starts.length == 2) {
+      return starts;
+    }
+    int resultIndex = 0;
+    for (int i = 1; i < starts.length - 1; i++) {
+      int upper = reader.leaves().get(i).docBase;
+      resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
+      if (resultIndex < 0) {
+        resultIndex = -1 - resultIndex;
+      }
+      starts[i] = resultIndex;
+    }
+    return starts;
+  }
+
+  @Override
+  public String toString(String field) {
+    return "<vector:" + this.field + "[" + target[0] + ",...][" + k + "]>";
+  }
+
+  @Override
+  public void visit(QueryVisitor visitor) {
+    if (visitor.acceptField(field)) {
+      visitor.visitLeaf(this);
+    }
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    return obj instanceof KnnVectorQuery
+        && ((KnnVectorQuery) obj).k == k
+        && ((KnnVectorQuery) obj).field.equals(field)
+        && Arrays.equals(((KnnVectorQuery) obj).target, target);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(field, k, Arrays.hashCode(target));
+  }
+
+  /** Caches the results of a KnnVector search: a list of docs and their scores */
+  class DocAndScoreQuery extends Query {
+
+    private final int[] docs;
+    private final float[] scores;
+    private final int[] segmentStarts;
+
+    /**
+     * Constructor
+     *
+     * @param docs the global docids of documents that match, in ascending order
+     * @param scores the scores of the matching documents
+     * @param segmentStarts the indexes in docs and scores corresponding to the first matching
+     *     document in each segment. If a segment has no matching documents, it should be assigned
+     *     the index of the next segment that does. There should be a final entry that is always
+     *     docs.length.
+     */
+    DocAndScoreQuery(int[] docs, float[] scores, int[] segmentStarts) {
+      this.docs = docs;
+      this.scores = scores;
+      this.segmentStarts = segmentStarts;
+    }
+
+    @Override
+    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
+        throws IOException {
+      return new Weight(this) {
+        @Override
+        public Explanation explain(LeafReaderContext context, int doc) {
+          int found = Arrays.binarySearch(docs, doc);
+          if (found < 0) {
+            return Explanation.noMatch("not in top " + k);
+          }
+          return Explanation.match(scores[found], "within top " + k);
+        }
+
+        @Override
+        public Scorer scorer(LeafReaderContext context) {
+
+          return new Scorer(this) {
+            final int lower = segmentStarts[context.ord];
+            final int upper = segmentStarts[context.ord + 1];
+            int upTo = -1;
+
+            @Override
+            public DocIdSetIterator iterator() {
+              return new DocIdSetIterator() {
+                @Override
+                public int docID() {
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int nextDoc() {
+                  if (upTo == -1) {
+                    upTo = lower;
+                  } else {
+                    ++upTo;
+                  }
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int advance(int target) throws IOException {
+                  return slowAdvance(target);
+                }
+
+                @Override
+                public long cost() {
+                  return upper - lower;
+                }
+              };
+            }
+
+            @Override
+            public float getMaxScore(int docid) {
+              docid += context.docBase;
+              float maxScore = 0;
+              for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docid; idx++) {
+                maxScore = Math.max(maxScore, scores[idx]);
+              }
+              return maxScore;
+            }
+
+            @Override
+            public float score() {
+              if (upTo >= lower && upTo < upper) {
+                return scores[upTo];
+              }
+              return 0;
+            }
+
+            /**
+             * Move the implementation of docID() into a differently-named method so we can call it
+             * from DocIdSetIterator.docID() even though this class is anonymous
+             *
+             * @return the current docid
+             */
+            private int docIdNoShadow() {
+              if (upTo == -1) {
+                return -1;
+              }
+              if (upTo >= upper) {
+                return NO_MORE_DOCS;
+              }
+              return docs[upTo] - context.docBase;
+            }
+
+            @Override
+            public int docID() {
+              return docIdNoShadow();
+            }
+          };
+        }
+
+        @Override
+        public boolean isCacheable(LeafReaderContext ctx) {

Review comment:
       I'm having some trouble following the argument. Seems like a good learning opportunity about caching :)
   
   Even if the BooleanWeight wasn't marked as cacheable, I think its other individual clauses are still eligible for caching. Or is the idea that we'd like to cache the BooleanWeight itself, which may be intersecting the kNN results with an expensive clause? In this case, I guess a takeaway is that a Weight should always be marked cacheable if it's valid, even if it wouldn't actually help performance on its own.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@lucene.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscribe@lucene.apache.org
For additional commands, e-mail: issues-help@lucene.apache.org


[GitHub] [lucene] jtibshirani commented on pull request #235: LUCENE-9614: add KnnVectorQuery implementation

Posted by GitBox <gi...@apache.org>.
jtibshirani commented on pull request #235:
URL: https://github.com/apache/lucene/pull/235#issuecomment-899643050


   I noticed a couple tiny things after you merged: e48be68, 29ed390.




[GitHub] [lucene] jpountz commented on a change in pull request #235: LUCENE-9614: add KnnVectorQuery implementation

Posted by GitBox <gi...@apache.org>.
jpountz commented on a change in pull request #235:
URL: https://github.com/apache/lucene/pull/235#discussion_r685777011



##########
File path: lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
##########
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Objects;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+
+/** Uses {@link KnnVectorsReader#search} to perform nearest neighbour search. */
+public class KnnVectorQuery extends Query {
+
+  private static final TopDocs NO_RESULTS =
+      new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+
+  private final String field;
+  private final float[] target;
+  private final int k;
+
+  /**
+   * Find the <code>k</code> nearest documents to the <code>target</code> vector according to the
+   * vectors in the given field.
+   *
+   * @param field a field that has been indexed as a {@link KnnVectorField}.
+   * @param target the target of the search
+   * @param k the number of documents to find
+   * @throws IllegalArgumentException if <code>k</code> is less than 1
+   */
+  public KnnVectorQuery(String field, float[] target, int k) {
+    this.field = field;
+    this.target = target;
+    this.k = k;
+    if (k < 1) {
+      throw new IllegalArgumentException("k must be at least 1, got: " + k);
+    }
+  }
+
+  @Override
+  public Query rewrite(IndexReader reader) throws IOException {
+    int boundedK = Math.min(k, reader.numDocs());
+    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
+    for (LeafReaderContext ctx : reader.leaves()) {
+      // Calculate kPerLeaf as an overestimate of the expected number of the closest k documents in
+      // this leaf
+      int expectedKPerLeaf = Math.max(1, boundedK * ctx.reader().numDocs() / reader.numDocs());
+      // Increase to include 3 std. deviations of a Binomial distribution.
+      int kPerLeaf = (int) (expectedKPerLeaf + 3 * Math.sqrt(expectedKPerLeaf));
+      perLeafResults[ctx.ord] = searchLeaf(ctx, kPerLeaf);
+    }
+    // Merge sort the results
+    TopDocs topK = TopDocs.merge(boundedK, perLeafResults);
+    // re-query any outlier segments (normally there should be none).
+    topK = checkForOutlierSegments(reader, topK, perLeafResults);
+    if (topK.scoreDocs.length == 0) {
+      return new MatchNoDocsQuery();
+    }
+    return createRewrittenQuery(reader, topK);
+  }
+
+  private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
+    TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf);
+    if (results == null) {
+      return NO_RESULTS;
+    }
+    if (ctx.docBase > 0) {
+      for (ScoreDoc scoreDoc : results.scoreDocs) {
+        scoreDoc.doc += ctx.docBase;
+      }
+    }
+    return results;
+  }
+
+  private TopDocs checkForOutlierSegments(IndexReader reader, TopDocs topK, TopDocs[] perLeaf)
+      throws IOException {
+    int k = topK.scoreDocs.length;
+    if (k == 0) {
+      return topK;
+    }
+    float minScore = topK.scoreDocs[topK.scoreDocs.length - 1].score;
+    boolean rescored = false;
+    for (int i = 0; i < perLeaf.length; i++) {
+      if (perLeaf[i].scoreDocs[perLeaf[i].scoreDocs.length - 1].score >= minScore) {
+        // This segment's worst score was competitive; search it again, gathering full K this time
+        perLeaf[i] = searchLeaf(reader.leaves().get(i), topK.scoreDocs.length);
+        rescored = true;
+      }
+    }
+    if (rescored) {
+      return TopDocs.merge(k, perLeaf);
+    } else {
+      return topK;
+    }
+  }
+
+  private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
+    int len = topK.scoreDocs.length;
+    float minScore = topK.scoreDocs[len - 1].score;
+    Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
+    int[] docs = new int[len];
+    float[] scores = new float[len];
+    for (int i = 0; i < len; i++) {
+      docs[i] = topK.scoreDocs[i].doc;
+      scores[i] = topK.scoreDocs[i].score - minScore; // flip negative scores
+    }
+    int[] segmentStarts = findSegmentStarts(reader, docs);
+    return new DocAndScoreQuery(docs, scores, segmentStarts);
+  }
+
+  private int[] findSegmentStarts(IndexReader reader, int[] docs) {
+    int[] starts = new int[reader.leaves().size() + 1];
+    starts[starts.length - 1] = docs.length;
+    if (starts.length == 2) {
+      return starts;
+    }
+    int resultIndex = 0;
+    for (int i = 1; i < starts.length - 1; i++) {
+      int upper = reader.leaves().get(i).docBase;
+      resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
+      if (resultIndex < 0) {
+        resultIndex = -1 - resultIndex;
+      }
+      starts[i] = resultIndex;
+    }
+    return starts;
+  }
+
+  @Override
+  public String toString(String field) {
+    return "<vector:" + this.field + "[" + target[0] + ",...][" + k + "]>";
+  }
+
+  @Override
+  public void visit(QueryVisitor visitor) {
+    if (visitor.acceptField(field)) {
+      visitor.visitLeaf(this);
+    }
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    return obj instanceof KnnVectorQuery
+        && ((KnnVectorQuery) obj).k == k
+        && ((KnnVectorQuery) obj).field.equals(field)
+        && Arrays.equals(((KnnVectorQuery) obj).target, target);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(field, k, Arrays.hashCode(target));
+  }
+
+  /** Caches the results of a KnnVector search: a list of docs and their scores */
+  class DocAndScoreQuery extends Query {
+
+    private final int[] docs;
+    private final float[] scores;
+    private final int[] segmentStarts;
+
+    /**
+     * Constructor
+     *
+     * @param docs the global docids of documents that match, in ascending order
+     * @param scores the scores of the matching documents
+     * @param segmentStarts the indexes in docs and scores corresponding to the first matching
+     *     document in each segment. If a segment has no matching documents, it should be assigned
+     *     the index of the next segment that does. There should be a final entry that is always
+     *     docs.length.
+     */
+    DocAndScoreQuery(int[] docs, float[] scores, int[] segmentStarts) {
+      this.docs = docs;
+      this.scores = scores;
+      this.segmentStarts = segmentStarts;
+    }
+
+    @Override
+    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
+        throws IOException {
+      return new Weight(this) {
+        @Override
+        public Explanation explain(LeafReaderContext context, int doc) {
+          int found = Arrays.binarySearch(docs, doc);
+          if (found < 0) {
+            return Explanation.noMatch("not in top " + k);
+          }
+          return Explanation.match(scores[found], "within top " + k);
+        }
+
+        @Override
+        public Scorer scorer(LeafReaderContext context) {
+
+          return new Scorer(this) {
+            final int lower = segmentStarts[context.ord];
+            final int upper = segmentStarts[context.ord + 1];
+            int upTo = -1;
+
+            @Override
+            public DocIdSetIterator iterator() {
+              return new DocIdSetIterator() {
+                @Override
+                public int docID() {
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int nextDoc() {
+                  if (upTo == -1) {
+                    upTo = lower;
+                  } else {
+                    ++upTo;
+                  }
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int advance(int target) throws IOException {
+                  return slowAdvance(target);
+                }
+
+                @Override
+                public long cost() {
+                  return upper - lower;
+                }
+              };
+            }
+
+            @Override
+            public float getMaxScore(int docid) {
+              docid += context.docBase;
+              float maxScore = 0;
+              for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docid; idx++) {
+                maxScore = Math.max(maxScore, scores[idx]);
+              }
+              return maxScore;
+            }

Review comment:
       Let's implement `advanceShallow` as well and make it move `upTo` forward?
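
To make the suggestion concrete, here is a minimal standalone sketch of how a shallow advance could work together with `getMaxScore` over the sorted `docs`/`scores` arrays. It models a single segment with plain arrays and does not use the Lucene `Scorer` API. The class name `ShallowAdvanceSketch` and the separate `shallowUpTo` cursor are assumptions made for illustration (the comment above proposes moving `upTo` itself), and the value returned by `advanceShallow` here is simply the next hit, which may differ from what the real contract requires.

```java
import java.util.Arrays;

/** Standalone model of one segment's slice of the cached hits; not the Lucene Scorer API. */
final class ShallowAdvanceSketch {
  static final int NO_MORE_DOCS = Integer.MAX_VALUE;

  private final int[] docs; // global doc ids, ascending
  private final float[] scores; // scores[i] belongs to docs[i]
  private final int upper; // one past this segment's last index
  private final int docBase; // global id of the segment's first document
  private int shallowUpTo; // hypothetical cursor, moved only by advanceShallow

  ShallowAdvanceSketch(int[] docs, float[] scores, int lower, int upper, int docBase) {
    this.docs = docs;
    this.scores = scores;
    this.upper = upper;
    this.docBase = docBase;
    this.shallowUpTo = lower;
  }

  /** Moves the cursor to the first hit at or after the segment-local target; returns its id. */
  int advanceShallow(int target) {
    long key = (long) target + docBase; // widen so NO_MORE_DOCS cannot overflow
    if (key > Integer.MAX_VALUE) {
      shallowUpTo = upper;
      return NO_MORE_DOCS;
    }
    int idx = Arrays.binarySearch(docs, shallowUpTo, upper, (int) key);
    if (idx < 0) {
      idx = -1 - idx; // insertion point: first hit greater than the target
    }
    shallowUpTo = idx; // the cursor never moves backwards
    return idx >= upper ? NO_MORE_DOCS : docs[idx] - docBase;
  }

  /** Best score among hits from the cursor up to and including the segment-local doc id. */
  float getMaxScore(int upToDoc) {
    long limit = (long) upToDoc + docBase;
    float max = 0;
    for (int idx = shallowUpTo; idx < upper && docs[idx] <= limit; idx++) {
      max = Math.max(max, scores[idx]);
    }
    return max;
  }
}
```

Because the cursor only moves forward, each `getMaxScore` call scans just the hits that remain after the last shallow advance, rather than restarting from the beginning of the segment's slice.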

##########
File path: lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
##########
@@ -0,0 +1,318 @@
+            @Override
+            public float score() {
+              if (upTo >= lower && upTo < upper) {
+                return scores[upTo];
+              }
+              return 0;

Review comment:
       When can this happen? Since score() should only be called when there is a match, wouldn't it be always legal to return `scores[upTo]`?
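
The point about the iterator contract can be illustrated with a small standalone model of the cursor, again using plain arrays rather than the Lucene API. The class name `PositionedScoreSketch` is hypothetical. Throwing from `score()` when the cursor is not on a document is a choice made for this sketch so that a contract violation becomes visible; the patch returns 0 in that case, and the comment above suggests returning `scores[upTo]` unconditionally.

```java
/** Standalone model of the per-segment cursor; not the Lucene Scorer API. */
final class PositionedScoreSketch {
  static final int NO_MORE_DOCS = Integer.MAX_VALUE;

  private final int[] docs; // global doc ids, ascending
  private final float[] scores; // scores[i] belongs to docs[i]
  private final int lower; // this segment's first index
  private final int upper; // one past this segment's last index
  private final int docBase; // global id of the segment's first document
  private int upTo = -1; // -1 means iteration has not started

  PositionedScoreSketch(int[] docs, float[] scores, int lower, int upper, int docBase) {
    this.docs = docs;
    this.scores = scores;
    this.lower = lower;
    this.upper = upper;
    this.docBase = docBase;
  }

  /** Steps to the next hit in this segment and returns its segment-local id. */
  int nextDoc() {
    upTo = (upTo == -1) ? lower : upTo + 1;
    return docID();
  }

  /** Segment-local id of the current hit, -1 before iteration, NO_MORE_DOCS after it. */
  int docID() {
    if (upTo == -1) {
      return -1;
    }
    if (upTo >= upper) {
      return NO_MORE_DOCS;
    }
    return docs[upTo] - docBase;
  }

  /** Score of the current hit; only meaningful while positioned on a document. */
  float score() {
    if (upTo < lower || upTo >= upper) {
      // Assumption for this sketch: fail loudly instead of returning a default score.
      throw new IllegalStateException("score() called while not positioned on a document");
    }
    return scores[upTo];
  }
}
```

Whenever `nextDoc()` returns a real document, `upTo` lies in `[lower, upper)`, so in this model the guard can only fire when `score()` is called before iteration starts or after it is exhausted.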

##########
File path: lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
##########
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Objects;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+
+/** Uses {@link KnnVectorsReader#search} to perform nearest neighbour search. */
+public class KnnVectorQuery extends Query {
+
+  private static final TopDocs NO_RESULTS =
+      new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+
+  private final String field;
+  private final float[] target;
+  private final int k;
+
+  /**
+   * Find the <code>k</code> nearest documents to the <code>target</code> vector according to the
+   * vectors in the given field.
+   *
+   * @param field a field that has been indexed as a {@link KnnVectorField}.
+   * @param target the target of the search
+   * @param k the number of documents to find
+   * @throws IllegalArgumentException if <code>k</code> is less than 1
+   */
+  public KnnVectorQuery(String field, float[] target, int k) {
+    this.field = field;
+    this.target = target;
+    this.k = k;
+    if (k < 1) {
+      throw new IllegalArgumentException("k must be at least 1, got: " + k);
+    }
+  }
+
+  @Override
+  public Query rewrite(IndexReader reader) throws IOException {
+    int boundedK = Math.min(k, reader.numDocs());
+    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
+    for (LeafReaderContext ctx : reader.leaves()) {
+      // Calculate kPerLeaf as an overestimate of the expected number of the closest k documents in
+      // this leaf
+      int expectedKPerLeaf = Math.max(1, boundedK * ctx.reader().numDocs() / reader.numDocs());
+      // Increase to include 3 std. deviations of a Binomial distribution.
+      int kPerLeaf = (int) (expectedKPerLeaf + 3 * Math.sqrt(expectedKPerLeaf));
+      perLeafResults[ctx.ord] = searchLeaf(ctx, kPerLeaf);
+    }
+    // Merge sort the results
+    TopDocs topK = TopDocs.merge(boundedK, perLeafResults);
+    // re-query any outlier segments (normally there should be none).
+    topK = checkForOutlierSegments(reader, topK, perLeafResults);
+    if (topK.scoreDocs.length == 0) {
+      return new MatchNoDocsQuery();
+    }
+    return createRewrittenQuery(reader, topK);
+  }
+
+  private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
+    TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf);
+    if (results == null) {
+      return NO_RESULTS;
+    }
+    if (ctx.docBase > 0) {
+      for (ScoreDoc scoreDoc : results.scoreDocs) {
+        scoreDoc.doc += ctx.docBase;
+      }
+    }
+    return results;
+  }
+
+  private TopDocs checkForOutlierSegments(IndexReader reader, TopDocs topK, TopDocs[] perLeaf)
+      throws IOException {
+    int k = topK.scoreDocs.length;
+    if (k == 0) {
+      return topK;
+    }
+    float minScore = topK.scoreDocs[topK.scoreDocs.length - 1].score;
+    boolean rescored = false;
+    for (int i = 0; i < perLeaf.length; i++) {
+      if (perLeaf[i].scoreDocs[perLeaf[i].scoreDocs.length - 1].score >= minScore) {
+        // This segment's worst score was competitive; search it again, gathering full K this time
+        perLeaf[i] = searchLeaf(reader.leaves().get(i), topK.scoreDocs.length);
+        rescored = true;
+      }
+    }
+    if (rescored) {
+      return TopDocs.merge(k, perLeaf);
+    } else {
+      return topK;
+    }
+  }
+
+  private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
+    int len = topK.scoreDocs.length;
+    float minScore = topK.scoreDocs[len - 1].score;
+    Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
+    int[] docs = new int[len];
+    float[] scores = new float[len];
+    for (int i = 0; i < len; i++) {
+      docs[i] = topK.scoreDocs[i].doc;
+      scores[i] = topK.scoreDocs[i].score - minScore; // flip negative scores
+    }
+    int[] segmentStarts = findSegmentStarts(reader, docs);
+    return new DocAndScoreQuery(docs, scores, segmentStarts);
+  }
+
+  private int[] findSegmentStarts(IndexReader reader, int[] docs) {
+    int[] starts = new int[reader.leaves().size() + 1];
+    starts[starts.length - 1] = docs.length;
+    if (starts.length == 2) {
+      return starts;
+    }
+    int resultIndex = 0;
+    for (int i = 1; i < starts.length - 1; i++) {
+      int upper = reader.leaves().get(i).docBase;
+      resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
+      if (resultIndex < 0) {
+        resultIndex = -1 - resultIndex;
+      }
+      starts[i] = resultIndex;
+    }
+    return starts;
+  }
+
+  @Override
+  public String toString(String field) {
+    return "<vector:" + this.field + "[" + target[0] + ",...][" + k + "]>";
+  }
+
+  @Override
+  public void visit(QueryVisitor visitor) {
+    if (visitor.acceptField(field)) {
+      visitor.visitLeaf(this);
+    }
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    return obj instanceof KnnVectorQuery
+        && ((KnnVectorQuery) obj).k == k
+        && ((KnnVectorQuery) obj).field.equals(field)
+        && Arrays.equals(((KnnVectorQuery) obj).target, target);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(field, k, Arrays.hashCode(target));
+  }
+
+  /** Caches the results of a KnnVector search: a list of docs and their scores */
+  class DocAndScoreQuery extends Query {
+
+    private final int[] docs;
+    private final float[] scores;
+    private final int[] segmentStarts;
+
+    /**
+     * Constructor
+     *
+     * @param docs the global docids of documents that match, in ascending order
+     * @param scores the scores of the matching documents
+     * @param segmentStarts the indexes in docs and scores corresponding to the first matching
+     *     document in each segment. If a segment has no matching documents, it should be assigned
+     *     the index of the next segment that does. There should be a final entry that is always
+     *     docs.length (used as an exclusive upper bound).
+     */
+    DocAndScoreQuery(int[] docs, float[] scores, int[] segmentStarts) {
+      this.docs = docs;
+      this.scores = scores;
+      this.segmentStarts = segmentStarts;
+    }
+
+    @Override
+    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
+        throws IOException {
+      return new Weight(this) {
+        @Override
+        public Explanation explain(LeafReaderContext context, int doc) {
+          int found = Arrays.binarySearch(docs, doc);
+          if (found < 0) {
+            return Explanation.noMatch("not in top " + k);
+          }
+          return Explanation.match(scores[found], "within top " + k);
+        }
+
+        @Override
+        public Scorer scorer(LeafReaderContext context) {
+
+          return new Scorer(this) {
+            final int lower = segmentStarts[context.ord];
+            final int upper = segmentStarts[context.ord + 1];
+            int upTo = -1;
+
+            @Override
+            public DocIdSetIterator iterator() {
+              return new DocIdSetIterator() {
+                @Override
+                public int docID() {
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int nextDoc() {
+                  if (upTo == -1) {
+                    upTo = lower;
+                  } else {
+                    ++upTo;
+                  }
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int advance(int target) throws IOException {
+                  return slowAdvance(target);
+                }
+
+                @Override
+                public long cost() {
+                  return upper - lower;
+                }
+              };
+            }
+
+            @Override
+            public float getMaxScore(int docid) {
+              docid += context.docBase;
+              float maxScore = 0;
+              for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docid; idx++) {
+                maxScore = Math.max(maxScore, scores[idx]);
+              }
+              return maxScore;
+            }
+
+            @Override
+            public float score() {
+              if (upTo >= lower && upTo < upper) {
+                return scores[upTo];
+              }
+              return 0;
+            }
+
+            /**
+             * move the implementation of docIO() into a differently-named method so we can call it
+             * from DocIDSetIterator.docID() even though this class is anonymous
+             *
+             * @return the current docid
+             */
+            private int docIdNoShadow() {
+              if (upTo == -1) {
+                return -1;
+              }
+              if (upTo >= upper) {
+                return NO_MORE_DOCS;
+              }
+              return docs[upTo] - context.docBase;
+            }
+
+            @Override
+            public int docID() {
+              return docIdNoShadow();
+            }
+          };
+        }
+
+        @Override
+        public boolean isCacheable(LeafReaderContext ctx) {
+          return true;
+        }
+      };
+    }
+
+    @Override
+    public String toString(String field) {
+      return KnnVectorQuery.this.toString();
+    }
+
+    @Override
+    public void visit(QueryVisitor visitor) {
+      KnnVectorQuery.this.visit(visitor);
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+      if (obj instanceof DocAndScoreQuery == false) {
+        return false;
+      }
+      return Arrays.equals(docs, ((DocAndScoreQuery) obj).docs)
+          && Arrays.equals(scores, ((DocAndScoreQuery) obj).scores);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(Arrays.hashCode(docs), Arrays.hashCode(scores));

Review comment:
       Can you use the `classHash()` in the hash code too?
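
A minimal sketch of what that could look like, not the code in this PR. `BaseQuery` is a hypothetical stand-in modelled on the `classHash()` and `sameClassAs()` helpers of Lucene's `Query`, so the example runs without Lucene on the classpath; `DocAndScore` is a simplified stand-in for `DocAndScoreQuery`.

```java
import java.util.Arrays;

public class ClassHashSketch {

  /** Hypothetical stand-in for Lucene's Query; only the two helpers used below are mirrored. */
  abstract static class BaseQuery {
    protected final int classHash() {
      return getClass().getName().hashCode();
    }

    protected final boolean sameClassAs(Object other) {
      return other != null && getClass() == other.getClass();
    }
  }

  /** Simplified DocAndScoreQuery: equality and hashing cover the class plus both arrays. */
  static final class DocAndScore extends BaseQuery {
    private final int[] docs;
    private final float[] scores;

    DocAndScore(int[] docs, float[] scores) {
      this.docs = docs;
      this.scores = scores;
    }

    @Override
    public boolean equals(Object obj) {
      if (!sameClassAs(obj)) {
        return false;
      }
      DocAndScore other = (DocAndScore) obj;
      return Arrays.equals(docs, other.docs) && Arrays.equals(scores, other.scores);
    }

    @Override
    public int hashCode() {
      // Seed with the class hash so that a different query type holding equal
      // arrays is unlikely to collide with this one.
      int h = classHash();
      h = 31 * h + Arrays.hashCode(docs);
      h = 31 * h + Arrays.hashCode(scores);
      return h;
    }
  }

  public static void main(String[] args) {
    DocAndScore a = new DocAndScore(new int[] {1, 5}, new float[] {0.5f, 0f});
    DocAndScore b = new DocAndScore(new int[] {1, 5}, new float[] {0.5f, 0f});
    System.out.println(a.equals(b) && a.hashCode() == b.hashCode()); // prints true
  }
}
```

Using `sameClassAs()` in `equals()` alongside `classHash()` in `hashCode()` keeps the two methods consistent with each other.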

##########
File path: lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
##########
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Objects;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+
+/** Uses {@link KnnVectorsReader#search} to perform nearest Neighbour search. */
+public class KnnVectorQuery extends Query {
+
+  private static final TopDocs NO_RESULTS =
+      new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+
+  private final String field;
+  private final float[] target;
+  private final int k;
+
+  /**
+   * Find the <code>k</code> nearest documents to the <code>target</code> vector according to the
+   * vectors in the given field.
+   *
+   * @param field a field that has been indexed as a {@link KnnVectorField}.
+   * @param target the target of the search
+   * @param k the number of documents to find
+   * @throws IllegalArgumentException if <code>k</code> is less than 1
+   */
+  public KnnVectorQuery(String field, float[] target, int k) {
+    this.field = field;
+    this.target = target;
+    this.k = k;
+    if (k < 1) {
+      throw new IllegalArgumentException("k must be at least 1, got: " + k);
+    }
+  }
+
+  @Override
+  public Query rewrite(IndexReader reader) throws IOException {
+    int boundedK = Math.min(k, reader.numDocs());
+    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
+    for (LeafReaderContext ctx : reader.leaves()) {
+      // Calculate kPerLeaf as an overestimate of the expected number of the closest k documents in
+      // this leaf
+      int expectedKPerLeaf = Math.max(1, boundedK * ctx.reader().numDocs() / reader.numDocs());
+      // Increase to include 3 std. deviations of a Binomial distribution.
+      int kPerLeaf = (int) (expectedKPerLeaf + 3 * Math.sqrt(expectedKPerLeaf));
+      perLeafResults[ctx.ord] = searchLeaf(ctx, kPerLeaf);
+    }
+    // Merge sort the results
+    TopDocs topK = TopDocs.merge(boundedK, perLeafResults);
+    // re-query any outlier segments (normally there should be none).
+    topK = checkForOutlierSegments(reader, topK, perLeafResults);
+    if (topK.scoreDocs.length == 0) {
+      return new MatchNoDocsQuery();
+    }
+    return createRewrittenQuery(reader, topK);
+  }
+
+  private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
+    TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf);
+    if (results == null) {
+      return NO_RESULTS;
+    }
+    if (ctx.docBase > 0) {
+      for (ScoreDoc scoreDoc : results.scoreDocs) {
+        scoreDoc.doc += ctx.docBase;
+      }
+    }
+    return results;
+  }
+
+  private TopDocs checkForOutlierSegments(IndexReader reader, TopDocs topK, TopDocs[] perLeaf)
+      throws IOException {
+    int k = topK.scoreDocs.length;
+    if (k == 0) {
+      return topK;
+    }
+    float minScore = topK.scoreDocs[topK.scoreDocs.length - 1].score;
+    boolean rescored = false;
+    for (int i = 0; i < perLeaf.length; i++) {
+      if (perLeaf[i].scoreDocs[perLeaf[i].scoreDocs.length - 1].score >= minScore) {
+        // This segment's worst score was competitive; search it again, gathering full K this time
+        perLeaf[i] = searchLeaf(reader.leaves().get(i), topK.scoreDocs.length);
+        rescored = true;
+      }
+    }
+    if (rescored) {
+      return TopDocs.merge(k, perLeaf);
+    } else {
+      return topK;
+    }
+  }
+
+  private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
+    int len = topK.scoreDocs.length;
+    float minScore = topK.scoreDocs[len - 1].score;
+    Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
+    int[] docs = new int[len];
+    float[] scores = new float[len];
+    for (int i = 0; i < len; i++) {
+      docs[i] = topK.scoreDocs[i].doc;
+      scores[i] = topK.scoreDocs[i].score - minScore; // flip negative scores
+    }
+    int[] segmentStarts = findSegmentStarts(reader, docs);
+    return new DocAndScoreQuery(docs, scores, segmentStarts);
+  }
+
+  private int[] findSegmentStarts(IndexReader reader, int[] docs) {
+    int[] starts = new int[reader.leaves().size() + 1];
+    starts[starts.length - 1] = docs.length;
+    if (starts.length == 2) {
+      return starts;
+    }
+    int resultIndex = 0;
+    for (int i = 1; i < starts.length - 1; i++) {
+      int upper = reader.leaves().get(i).docBase;
+      resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
+      if (resultIndex < 0) {
+        resultIndex = -1 - resultIndex;
+      }
+      starts[i] = resultIndex;
+    }
+    return starts;
+  }
+
+  @Override
+  public String toString(String field) {
+    return "<vector:" + this.field + "[" + target[0] + ",...][" + k + "]>";
+  }
+
+  @Override
+  public void visit(QueryVisitor visitor) {
+    if (visitor.acceptField(field)) {
+      visitor.visitLeaf(this);
+    }
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    return obj instanceof KnnVectorQuery
+        && ((KnnVectorQuery) obj).k == k
+        && ((KnnVectorQuery) obj).field.equals(field)
+        && Arrays.equals(((KnnVectorQuery) obj).target, target);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(field, k, Arrays.hashCode(target));
+  }
+
+  /** Caches the results of a KnnVector search: a list of docs and their scores */
+  class DocAndScoreQuery extends Query {

Review comment:
       nit: make it static so that it doesn't hold a reference to the KnnVectorQuery?
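
A minimal sketch of the idea, with hypothetical names (`StaticNestedSketch`, `Results`) and no Lucene dependency. Since the nested class in the PR reads `k` from the enclosing query, a static version would presumably need that value passed in through its constructor, as shown here.

```java
import java.util.Arrays;

public class StaticNestedSketch {
  private final int k;

  StaticNestedSketch(int k) {
    this.k = k;
  }

  /** Static nested class: no hidden reference to the enclosing instance; k is copied in. */
  static final class Results {
    private final int[] docs; // global doc ids, sorted ascending
    private final int k;

    Results(int[] docs, int k) {
      this.docs = docs;
      this.k = k;
    }

    String explain(int doc) {
      return Arrays.binarySearch(docs, doc) >= 0 ? "within top " + k : "not in top " + k;
    }
  }

  /** Plays the role of rewrite(): hands the nested class everything it needs. */
  Results rewrite(int[] sortedDocs) {
    return new Results(sortedDocs, k);
  }

  public static void main(String[] args) {
    Results r = new StaticNestedSketch(10).rewrite(new int[] {2, 7, 40});
    System.out.println(r.explain(7)); // within top 10
    System.out.println(r.explain(8)); // not in top 10
  }
}
```

With this shape the cached results can outlive the original query object without keeping its target vector reachable.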

##########
File path: lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
##########
@@ -0,0 +1,318 @@
+package org.apache.lucene.search;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Objects;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+
+/** Uses {@link KnnVectorsReader#search} to perform nearest Neighbour search. */
+public class KnnVectorQuery extends Query {
+
+  private static final TopDocs NO_RESULTS =
+      new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+
+  private final String field;
+  private final float[] target;
+  private final int k;
+
+  /**
+   * Find the <code>k</code> nearest documents to the <code>target</code> vector according to the
+   * vectors in the given field.
+   *
+   * @param field a field that has been indexed as a {@link KnnVectorField}.
+   * @param target the target of the search
+   * @param k the number of documents to find
+   * @throws IllegalArgumentException if <code>k</code> is less than 1
+   */
+  public KnnVectorQuery(String field, float[] target, int k) {
+    this.field = field;
+    this.target = target;
+    this.k = k;
+    if (k < 1) {
+      throw new IllegalArgumentException("k must be at least 1, got: " + k);
+    }
+  }
+
+  @Override
+  public Query rewrite(IndexReader reader) throws IOException {
+    int boundedK = Math.min(k, reader.numDocs());
+    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
+    for (LeafReaderContext ctx : reader.leaves()) {
+      // Calculate kPerLeaf as an overestimate of the expected number of the closest k documents in
+      // this leaf
+      int expectedKPerLeaf = Math.max(1, boundedK * ctx.reader().numDocs() / reader.numDocs());
+      // Increase to include 3 std. deviations of a Binomial distribution.
+      int kPerLeaf = (int) (expectedKPerLeaf + 3 * Math.sqrt(expectedKPerLeaf));
+      perLeafResults[ctx.ord] = searchLeaf(ctx, kPerLeaf);
+    }
+    // Merge sort the results
+    TopDocs topK = TopDocs.merge(boundedK, perLeafResults);
+    // re-query any outlier segments (normally there should be none).
+    topK = checkForOutlierSegments(reader, topK, perLeafResults);
+    if (topK.scoreDocs.length == 0) {
+      return new MatchNoDocsQuery();
+    }
+    return createRewrittenQuery(reader, topK);
+  }
+
+  private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
+    TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf);
+    if (results == null) {
+      return NO_RESULTS;
+    }
+    if (ctx.docBase > 0) {
+      for (ScoreDoc scoreDoc : results.scoreDocs) {
+        scoreDoc.doc += ctx.docBase;
+      }
+    }
+    return results;
+  }
+
+  private TopDocs checkForOutlierSegments(IndexReader reader, TopDocs topK, TopDocs[] perLeaf)
+      throws IOException {
+    int k = topK.scoreDocs.length;
+    if (k == 0) {
+      return topK;
+    }
+    float minScore = topK.scoreDocs[topK.scoreDocs.length - 1].score;
+    boolean rescored = false;
+    for (int i = 0; i < perLeaf.length; i++) {
+      if (perLeaf[i].scoreDocs[perLeaf[i].scoreDocs.length - 1].score >= minScore) {
+        // This segment's worst score was competitive; search it again, gathering full K this time
+        perLeaf[i] = searchLeaf(reader.leaves().get(i), topK.scoreDocs.length);
+        rescored = true;
+      }
+    }
+    if (rescored) {
+      return TopDocs.merge(k, perLeaf);
+    } else {
+      return topK;
+    }
+  }
+
+  private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
+    int len = topK.scoreDocs.length;
+    float minScore = topK.scoreDocs[len - 1].score;
+    Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
+    int[] docs = new int[len];
+    float[] scores = new float[len];
+    for (int i = 0; i < len; i++) {
+      docs[i] = topK.scoreDocs[i].doc;
+      scores[i] = topK.scoreDocs[i].score - minScore; // flip negative scores
+    }
+    int[] segmentStarts = findSegmentStarts(reader, docs);
+    return new DocAndScoreQuery(docs, scores, segmentStarts);
+  }
+
+  private int[] findSegmentStarts(IndexReader reader, int[] docs) {
+    int[] starts = new int[reader.leaves().size() + 1];
+    starts[starts.length - 1] = docs.length;
+    if (starts.length == 2) {
+      return starts;
+    }
+    int resultIndex = 0;
+    for (int i = 1; i < starts.length - 1; i++) {
+      int upper = reader.leaves().get(i).docBase;
+      resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
+      if (resultIndex < 0) {
+        resultIndex = -1 - resultIndex;
+      }
+      starts[i] = resultIndex;
+    }
+    return starts;
+  }
+
+  @Override
+  public String toString(String field) {
+    return "<vector:" + this.field + "[" + target[0] + ",...][" + k + "]>";
+  }
+
+  @Override
+  public void visit(QueryVisitor visitor) {
+    if (visitor.acceptField(field)) {
+      visitor.visitLeaf(this);
+    }
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    return obj instanceof KnnVectorQuery
+        && ((KnnVectorQuery) obj).k == k
+        && ((KnnVectorQuery) obj).field.equals(field)
+        && Arrays.equals(((KnnVectorQuery) obj).target, target);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(field, k, Arrays.hashCode(target));
+  }
+
+  /** Caches the results of a KnnVector search: a list of docs and their scores */
+  class DocAndScoreQuery extends Query {
+
+    private final int[] docs;
+    private final float[] scores;
+    private final int[] segmentStarts;
+
+    /**
+     * Constructor
+     *
+     * @param docs the global docids of documents that match, in ascending order
+     * @param scores the scores of the matching documents
+     * @param segmentStarts the indexes in docs and scores corresponding to the first matching
+     *     document in each segment. If a segment has no matching documents, it should be assigned
+     *     the index of the next segment that does. There should be a final entry that is always
+     *     docs.length (used as an exclusive upper bound).
+     */
+    DocAndScoreQuery(int[] docs, float[] scores, int[] segmentStarts) {
+      this.docs = docs;
+      this.scores = scores;
+      this.segmentStarts = segmentStarts;
+    }
+
+    @Override
+    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
+        throws IOException {
+      return new Weight(this) {
+        @Override
+        public Explanation explain(LeafReaderContext context, int doc) {
+          int found = Arrays.binarySearch(docs, doc);
+          if (found < 0) {
+            return Explanation.noMatch("not in top " + k);
+          }
+          return Explanation.match(scores[found], "within top " + k);
+        }
+
+        @Override
+        public Scorer scorer(LeafReaderContext context) {
+
+          return new Scorer(this) {
+            final int lower = segmentStarts[context.ord];
+            final int upper = segmentStarts[context.ord + 1];
+            int upTo = -1;
+
+            @Override
+            public DocIdSetIterator iterator() {
+              return new DocIdSetIterator() {
+                @Override
+                public int docID() {
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int nextDoc() {
+                  if (upTo == -1) {
+                    upTo = lower;
+                  } else {
+                    ++upTo;
+                  }
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int advance(int target) throws IOException {
+                  return slowAdvance(target);
+                }
+
+                @Override
+                public long cost() {
+                  return upper - lower;
+                }
+              };
+            }
+
+            @Override
+            public float getMaxScore(int docid) {
+              docid += context.docBase;
+              float maxScore = 0;
+              for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docid; idx++) {
+                maxScore = Math.max(maxScore, scores[idx]);
+              }
+              return maxScore;
+            }
+
+            @Override
+            public float score() {
+              if (upTo >= lower && upTo < upper) {
+                return scores[upTo];
+              }
+              return 0;
+            }
+
+            /**
+             * move the implementation of docIO() into a differently-named method so we can call it

Review comment:
       ```suggestion
                * move the implementation of docID() into a differently-named method so we can call it
       ```

##########
File path: lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
##########
@@ -0,0 +1,318 @@
+package org.apache.lucene.search;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Objects;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+
+/** Uses {@link KnnVectorsReader#search} to perform nearest Neighbour search. */
+public class KnnVectorQuery extends Query {
+
+  private static final TopDocs NO_RESULTS =
+      new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+
+  private final String field;
+  private final float[] target;
+  private final int k;
+
+  /**
+   * Find the <code>k</code> nearest documents to the <code>target</code> vector according to the
+   * vectors in the given field.
+   *
+   * @param field a field that has been indexed as a {@link KnnVectorField}.
+   * @param target the target of the search
+   * @param k the number of documents to find
+   * @throws IllegalArgumentException if <code>k</code> is less than 1
+   */
+  public KnnVectorQuery(String field, float[] target, int k) {
+    this.field = field;
+    this.target = target;
+    this.k = k;
+    if (k < 1) {
+      throw new IllegalArgumentException("k must be at least 1, got: " + k);
+    }
+  }
+
+  @Override
+  public Query rewrite(IndexReader reader) throws IOException {
+    int boundedK = Math.min(k, reader.numDocs());
+    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
+    for (LeafReaderContext ctx : reader.leaves()) {
+      // Calculate kPerLeaf as an overestimate of the expected number of the closest k documents in
+      // this leaf
+      int expectedKPerLeaf = Math.max(1, boundedK * ctx.reader().numDocs() / reader.numDocs());
+      // Increase to include 3 std. deviations of a Binomial distribution.
+      int kPerLeaf = (int) (expectedKPerLeaf + 3 * Math.sqrt(expectedKPerLeaf));
+      perLeafResults[ctx.ord] = searchLeaf(ctx, kPerLeaf);
+    }
+    // Merge sort the results
+    TopDocs topK = TopDocs.merge(boundedK, perLeafResults);
+    // re-query any outlier segments (normally there should be none).
+    topK = checkForOutlierSegments(reader, topK, perLeafResults);
+    if (topK.scoreDocs.length == 0) {
+      return new MatchNoDocsQuery();
+    }
+    return createRewrittenQuery(reader, topK);
+  }
+
+  private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
+    TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf);
+    if (results == null) {
+      return NO_RESULTS;
+    }
+    if (ctx.docBase > 0) {
+      for (ScoreDoc scoreDoc : results.scoreDocs) {
+        scoreDoc.doc += ctx.docBase;
+      }
+    }
+    return results;
+  }
+
+  private TopDocs checkForOutlierSegments(IndexReader reader, TopDocs topK, TopDocs[] perLeaf)
+      throws IOException {
+    int k = topK.scoreDocs.length;
+    if (k == 0) {
+      return topK;
+    }
+    float minScore = topK.scoreDocs[topK.scoreDocs.length - 1].score;
+    boolean rescored = false;
+    for (int i = 0; i < perLeaf.length; i++) {
+      if (perLeaf[i].scoreDocs[perLeaf[i].scoreDocs.length - 1].score >= minScore) {
+        // This segment's worst score was competitive; search it again, gathering full K this time
+        perLeaf[i] = searchLeaf(reader.leaves().get(i), topK.scoreDocs.length);
+        rescored = true;
+      }
+    }
+    if (rescored) {
+      return TopDocs.merge(k, perLeaf);
+    } else {
+      return topK;
+    }
+  }
+
+  private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
+    int len = topK.scoreDocs.length;
+    float minScore = topK.scoreDocs[len - 1].score;
+    Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
+    int[] docs = new int[len];
+    float[] scores = new float[len];
+    for (int i = 0; i < len; i++) {
+      docs[i] = topK.scoreDocs[i].doc;
+      scores[i] = topK.scoreDocs[i].score - minScore; // flip negative scores
+    }
+    int[] segmentStarts = findSegmentStarts(reader, docs);
+    return new DocAndScoreQuery(docs, scores, segmentStarts);
+  }
+
+  private int[] findSegmentStarts(IndexReader reader, int[] docs) {
+    int[] starts = new int[reader.leaves().size() + 1];
+    starts[starts.length - 1] = docs.length;
+    if (starts.length == 2) {
+      return starts;
+    }
+    int resultIndex = 0;
+    for (int i = 1; i < starts.length - 1; i++) {
+      int upper = reader.leaves().get(i).docBase;
+      resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
+      if (resultIndex < 0) {
+        resultIndex = -1 - resultIndex;
+      }
+      starts[i] = resultIndex;
+    }
+    return starts;
+  }
+
+  @Override
+  public String toString(String field) {
+    return "<vector:" + this.field + "[" + target[0] + ",...][" + k + "]>";
+  }
+
+  @Override
+  public void visit(QueryVisitor visitor) {
+    if (visitor.acceptField(field)) {
+      visitor.visitLeaf(this);
+    }
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    return obj instanceof KnnVectorQuery
+        && ((KnnVectorQuery) obj).k == k
+        && ((KnnVectorQuery) obj).field.equals(field)
+        && Arrays.equals(((KnnVectorQuery) obj).target, target);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(field, k, Arrays.hashCode(target));
+  }
+
+  /** Caches the results of a KnnVector search: a list of docs and their scores */
+  class DocAndScoreQuery extends Query {
+
+    private final int[] docs;
+    private final float[] scores;
+    private final int[] segmentStarts;
+
+    /**
+     * Constructor
+     *
+     * @param docs the global docids of documents that match, in ascending order
+     * @param scores the scores of the matching documents
+     * @param segmentStarts the indexes in docs and scores corresponding to the first matching
+     *     document in each segment. If a segment has no matching documents, it should be assigned
+     *     the index of the next segment that does. There should be a final entry that is always
+     *     docs.length (used as an exclusive upper bound).
+     */
+    DocAndScoreQuery(int[] docs, float[] scores, int[] segmentStarts) {
+      this.docs = docs;
+      this.scores = scores;
+      this.segmentStarts = segmentStarts;
+    }
+
+    @Override
+    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
+        throws IOException {
+      return new Weight(this) {
+        @Override
+        public Explanation explain(LeafReaderContext context, int doc) {
+          int found = Arrays.binarySearch(docs, doc);
+          if (found < 0) {
+            return Explanation.noMatch("not in top " + k);
+          }
+          return Explanation.match(scores[found], "within top " + k);
+        }
+
+        @Override
+        public Scorer scorer(LeafReaderContext context) {
+
+          return new Scorer(this) {
+            final int lower = segmentStarts[context.ord];
+            final int upper = segmentStarts[context.ord + 1];
+            int upTo = -1;
+
+            @Override
+            public DocIdSetIterator iterator() {
+              return new DocIdSetIterator() {
+                @Override
+                public int docID() {
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int nextDoc() {
+                  if (upTo == -1) {
+                    upTo = lower;
+                  } else {
+                    ++upTo;
+                  }
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int advance(int target) throws IOException {
+                  return slowAdvance(target);
+                }
+
+                @Override
+                public long cost() {
+                  return upper - lower;
+                }
+              };
+            }
+
+            @Override
+            public float getMaxScore(int docid) {
+              docid += context.docBase;
+              float maxScore = 0;
+              for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docid; idx++) {
+                maxScore = Math.max(maxScore, scores[idx]);
+              }
+              return maxScore;
+            }
+
+            @Override
+            public float score() {
+              if (upTo >= lower && upTo < upper) {
+                return scores[upTo];
+              }
+              return 0;
+            }
+
+            /**
+             * move the implementation of docIO() into a differently-named method so we can call it
+             * from DocIDSetIterator.docID() even though this class is anonymous
+             *
+             * @return the current docid
+             */
+            private int docIdNoShadow() {
+              if (upTo == -1) {
+                return -1;
+              }
+              if (upTo >= upper) {
+                return NO_MORE_DOCS;
+              }
+              return docs[upTo] - context.docBase;
+            }
+
+            @Override
+            public int docID() {
+              return docIdNoShadow();
+            }
+          };
+        }
+
+        @Override
+        public boolean isCacheable(LeafReaderContext ctx) {
+          return true;
+        }
+      };
+    }
+
+    @Override
+    public String toString(String field) {
+      return KnnVectorQuery.this.toString();
+    }
+
+    @Override
+    public void visit(QueryVisitor visitor) {
+      KnnVectorQuery.this.visit(visitor);
+    }

Review comment:
       Let's give this query its own toString() and visit()?
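
A minimal sketch of one possible shape, not a definitive implementation. `LeafVisitor` is a hypothetical stand-in for the `visitLeaf` callback of Lucene's `QueryVisitor`, and the string format returned by `toString()` is an assumption chosen only for illustration.

```java
import java.util.ArrayList;
import java.util.List;

public class OwnVisitSketch {

  /** Hypothetical stand-in for the single QueryVisitor callback used here. */
  interface LeafVisitor {
    void visitLeaf(Object query);
  }

  /** Simplified DocAndScoreQuery that describes and visits itself, not its parent. */
  static final class DocAndScore {
    private final int[] docs;

    DocAndScore(int[] docs) {
      this.docs = docs;
    }

    @Override
    public String toString() {
      // Describe the cached result set itself; no delegation to the enclosing query.
      return "DocAndScore[count=" + docs.length + "]";
    }

    void visit(LeafVisitor visitor) {
      // Report this query as the leaf rather than the query it was rewritten from.
      visitor.visitLeaf(this);
    }
  }

  public static void main(String[] args) {
    DocAndScore q = new DocAndScore(new int[] {3, 9, 12});
    List<Object> seen = new ArrayList<>();
    q.visit(seen::add);
    System.out.println(q + " visited=" + (seen.get(0) == q)); // DocAndScore[count=3] visited=true
  }
}
```

Removing the two `KnnVectorQuery.this` calls this way would also clear the path for making the class static.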

##########
File path: lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
##########
@@ -0,0 +1,318 @@
+package org.apache.lucene.search;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Objects;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+
+/** Uses {@link KnnVectorsReader#search} to perform nearest-neighbor search. */
+public class KnnVectorQuery extends Query {
+
+  private static final TopDocs NO_RESULTS =
+      new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+
+  private final String field;
+  private final float[] target;
+  private final int k;
+
+  /**
+   * Find the <code>k</code> nearest documents to the <code>target</code> vector according to the
+   * vectors in the given field.
+   *
+   * @param field a field that has been indexed as a {@link KnnVectorField}.
+   * @param target the target of the search
+   * @param k the number of documents to find
+   * @throws IllegalArgumentException if <code>k</code> is less than 1
+   */
+  public KnnVectorQuery(String field, float[] target, int k) {
+    this.field = field;
+    this.target = target;
+    this.k = k;
+    if (k < 1) {
+      throw new IllegalArgumentException("k must be at least 1, got: " + k);
+    }
+  }
+
+  @Override
+  public Query rewrite(IndexReader reader) throws IOException {
+    int boundedK = Math.min(k, reader.numDocs());
+    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
+    for (LeafReaderContext ctx : reader.leaves()) {
+      // Calculate kPerLeaf as an overestimate of the expected number of the closest k documents in
+      // this leaf
+      int expectedKPerLeaf = Math.max(1, boundedK * ctx.reader().numDocs() / reader.numDocs());
+      // Increase to include 3 std. deviations of a Binomial distribution.
+      int kPerLeaf = (int) (expectedKPerLeaf + 3 * Math.sqrt(expectedKPerLeaf));
+      perLeafResults[ctx.ord] = searchLeaf(ctx, kPerLeaf);
+    }
+    // Merge sort the results
+    TopDocs topK = TopDocs.merge(boundedK, perLeafResults);
+    // re-query any outlier segments (normally there should be none).
+    topK = checkForOutlierSegments(reader, topK, perLeafResults);
+    if (topK.scoreDocs.length == 0) {
+      return new MatchNoDocsQuery();
+    }
+    return createRewrittenQuery(reader, topK);
+  }
+
+  private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
+    TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf);
+    if (results == null) {
+      return NO_RESULTS;
+    }
+    if (ctx.docBase > 0) {
+      for (ScoreDoc scoreDoc : results.scoreDocs) {
+        scoreDoc.doc += ctx.docBase;
+      }
+    }
+    return results;
+  }
+
+  private TopDocs checkForOutlierSegments(IndexReader reader, TopDocs topK, TopDocs[] perLeaf)
+      throws IOException {
+    int k = topK.scoreDocs.length;
+    if (k == 0) {
+      return topK;
+    }
+    float minScore = topK.scoreDocs[topK.scoreDocs.length - 1].score;
+    boolean rescored = false;
+    for (int i = 0; i < perLeaf.length; i++) {
+      if (perLeaf[i].scoreDocs[perLeaf[i].scoreDocs.length - 1].score >= minScore) {
+        // This segment's worst score was competitive; search it again, gathering full K this time
+        perLeaf[i] = searchLeaf(reader.leaves().get(i), topK.scoreDocs.length);
+        rescored = true;
+      }
+    }
+    if (rescored) {
+      return TopDocs.merge(k, perLeaf);
+    } else {
+      return topK;
+    }
+  }
+
+  private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
+    int len = topK.scoreDocs.length;
+    float minScore = topK.scoreDocs[len - 1].score;
+    Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
+    int[] docs = new int[len];
+    float[] scores = new float[len];
+    for (int i = 0; i < len; i++) {
+      docs[i] = topK.scoreDocs[i].doc;
+      scores[i] = topK.scoreDocs[i].score - minScore; // flip negative scores
+    }
+    int[] segmentStarts = findSegmentStarts(reader, docs);
+    return new DocAndScoreQuery(docs, scores, segmentStarts);
+  }
+
+  private int[] findSegmentStarts(IndexReader reader, int[] docs) {
+    int[] starts = new int[reader.leaves().size() + 1];
+    starts[starts.length - 1] = docs.length;
+    if (starts.length == 2) {
+      return starts;
+    }
+    int resultIndex = 0;
+    for (int i = 1; i < starts.length - 1; i++) {
+      int upper = reader.leaves().get(i).docBase;
+      resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
+      if (resultIndex < 0) {
+        resultIndex = -1 - resultIndex;
+      }
+      starts[i] = resultIndex;
+    }
+    return starts;
+  }
+
+  @Override
+  public String toString(String field) {
+    return "<vector:" + this.field + "[" + target[0] + ",...][" + k + "]>";
+  }
+
+  @Override
+  public void visit(QueryVisitor visitor) {
+    if (visitor.acceptField(field)) {
+      visitor.visitLeaf(this);
+    }
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    return obj instanceof KnnVectorQuery
+        && ((KnnVectorQuery) obj).k == k
+        && ((KnnVectorQuery) obj).field.equals(field)
+        && Arrays.equals(((KnnVectorQuery) obj).target, target);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(field, k, Arrays.hashCode(target));
+  }
+
+  /** Caches the results of a KnnVector search: a list of docs and their scores */
+  class DocAndScoreQuery extends Query {
+
+    private final int[] docs;
+    private final float[] scores;
+    private final int[] segmentStarts;
+
+    /**
+     * Constructor
+     *
+     * @param docs the global docids of documents that match, in ascending order
+     * @param scores the scores of the matching documents
+     * @param segmentStarts the indexes in docs and scores corresponding to the first matching
+     *     document in each segment. If a segment has no matching documents, it should be assigned
+     *     the index of the next segment that does. There should be a final entry that is always
+     *     docs.length.
+     */
+    DocAndScoreQuery(int[] docs, float[] scores, int[] segmentStarts) {
+      this.docs = docs;
+      this.scores = scores;
+      this.segmentStarts = segmentStarts;
+    }
+
+    @Override
+    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
+        throws IOException {

Review comment:
       Maybe we should do a best-effort check that the reader wrapped by this searcher is effectively the same as the reader that was used for rewriting.
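
       A minimal sketch of what such a best-effort check could look like, written without any Lucene types so it stands alone. `ReaderBoundResults`, `checkSameReader` and the identity token are hypothetical names for illustration, and the sketch assumes some object that uniquely identifies the reader is available both at rewrite time and when the weight is created:

```java
/** Hypothetical stand-in for a rewritten query that remembers which reader produced it. */
final class ReaderBoundResults {
  // Assumption: an object that uniquely identifies the reader used during rewrite.
  private final Object readerIdentity;
  private final int[] docs;

  ReaderBoundResults(Object readerIdentity, int[] docs) {
    if (readerIdentity == null) {
      throw new NullPointerException("readerIdentity must not be null");
    }
    this.readerIdentity = readerIdentity;
    this.docs = docs.clone();
  }

  /** Fails fast if the reader seen now is not the same instance that was used for rewriting. */
  void checkSameReader(Object currentReaderIdentity) {
    // Reference comparison on purpose: the cached docids are only valid for that exact reader.
    if (readerIdentity != currentReaderIdentity) {
      throw new IllegalStateException("This query was rewritten against a different reader");
    }
  }

  int size() {
    return docs.length;
  }
}
```

       The check is only best-effort: it catches reuse of the rewritten query against another reader, but it cannot prove that the cached docids are still meaningful in every other respect.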




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@lucene.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscribe@lucene.apache.org
For additional commands, e-mail: issues-help@lucene.apache.org


[GitHub] [lucene] msokolov merged pull request #235: LUCENE-9614: add KnnVectorQuery implementation

Posted by GitBox <gi...@apache.org>.
msokolov merged pull request #235:
URL: https://github.com/apache/lucene/pull/235


   




[GitHub] [lucene] msokolov commented on a change in pull request #235: LUCENE-9614: add KnnVectorQuery implementation

Posted by GitBox <gi...@apache.org>.
msokolov commented on a change in pull request #235:
URL: https://github.com/apache/lucene/pull/235#discussion_r686280537



##########
File path: lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
##########
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Objects;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+
+/** Uses {@link KnnVectorsReader#search} to perform nearest-neighbor search. */
+public class KnnVectorQuery extends Query {
+
+  private static final TopDocs NO_RESULTS =
+      new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+
+  private final String field;
+  private final float[] target;
+  private final int k;
+
+  /**
+   * Find the <code>k</code> nearest documents to the <code>target</code> vector according to the
+   * vectors in the given field.
+   *
+   * @param field a field that has been indexed as a {@link KnnVectorField}.
+   * @param target the target of the search
+   * @param k the number of documents to find
+   * @throws IllegalArgumentException if <code>k</code> is less than 1
+   */
+  public KnnVectorQuery(String field, float[] target, int k) {
+    this.field = field;
+    this.target = target;
+    this.k = k;
+    if (k < 1) {
+      throw new IllegalArgumentException("k must be at least 1, got: " + k);
+    }
+  }
+
+  @Override
+  public Query rewrite(IndexReader reader) throws IOException {
+    int boundedK = Math.min(k, reader.numDocs());
+    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
+    for (LeafReaderContext ctx : reader.leaves()) {
+      // Calculate kPerLeaf as an overestimate of the expected number of the closest k documents in
+      // this leaf
+      int expectedKPerLeaf = Math.max(1, boundedK * ctx.reader().numDocs() / reader.numDocs());
+      // Increase to include 3 std. deviations of a Binomial distribution.
+      int kPerLeaf = (int) (expectedKPerLeaf + 3 * Math.sqrt(expectedKPerLeaf));
+      perLeafResults[ctx.ord] = searchLeaf(ctx, kPerLeaf);
+    }
+    // Merge sort the results
+    TopDocs topK = TopDocs.merge(boundedK, perLeafResults);
+    // re-query any outlier segments (normally there should be none).
+    topK = checkForOutlierSegments(reader, topK, perLeafResults);
+    if (topK.scoreDocs.length == 0) {
+      return new MatchNoDocsQuery();
+    }
+    return createRewrittenQuery(reader, topK);
+  }
+
+  private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
+    TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf);
+    if (results == null) {
+      return NO_RESULTS;
+    }
+    if (ctx.docBase > 0) {
+      for (ScoreDoc scoreDoc : results.scoreDocs) {
+        scoreDoc.doc += ctx.docBase;
+      }
+    }
+    return results;
+  }
+
+  private TopDocs checkForOutlierSegments(IndexReader reader, TopDocs topK, TopDocs[] perLeaf)
+      throws IOException {
+    int k = topK.scoreDocs.length;
+    if (k == 0) {
+      return topK;
+    }
+    float minScore = topK.scoreDocs[topK.scoreDocs.length - 1].score;
+    boolean rescored = false;
+    for (int i = 0; i < perLeaf.length; i++) {
+      if (perLeaf[i].scoreDocs[perLeaf[i].scoreDocs.length - 1].score >= minScore) {
+        // This segment's worst score was competitive; search it again, gathering full K this time
+        perLeaf[i] = searchLeaf(reader.leaves().get(i), topK.scoreDocs.length);
+        rescored = true;
+      }
+    }
+    if (rescored) {
+      return TopDocs.merge(k, perLeaf);
+    } else {
+      return topK;
+    }
+  }
+
+  private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
+    int len = topK.scoreDocs.length;
+    float minScore = topK.scoreDocs[len - 1].score;
+    Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
+    int[] docs = new int[len];
+    float[] scores = new float[len];
+    for (int i = 0; i < len; i++) {
+      docs[i] = topK.scoreDocs[i].doc;
+      scores[i] = topK.scoreDocs[i].score - minScore; // flip negative scores
+    }
+    int[] segmentStarts = findSegmentStarts(reader, docs);
+    return new DocAndScoreQuery(docs, scores, segmentStarts);
+  }
+
+  private int[] findSegmentStarts(IndexReader reader, int[] docs) {
+    int[] starts = new int[reader.leaves().size() + 1];
+    starts[starts.length - 1] = docs.length;
+    if (starts.length == 2) {
+      return starts;
+    }
+    int resultIndex = 0;
+    for (int i = 1; i < starts.length - 1; i++) {
+      int upper = reader.leaves().get(i).docBase;
+      resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
+      if (resultIndex < 0) {
+        resultIndex = -1 - resultIndex;
+      }
+      starts[i] = resultIndex;
+    }
+    return starts;
+  }
+
+  @Override
+  public String toString(String field) {
+    return "<vector:" + this.field + "[" + target[0] + ",...][" + k + "]>";
+  }
+
+  @Override
+  public void visit(QueryVisitor visitor) {
+    if (visitor.acceptField(field)) {
+      visitor.visitLeaf(this);
+    }
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    return obj instanceof KnnVectorQuery
+        && ((KnnVectorQuery) obj).k == k
+        && ((KnnVectorQuery) obj).field.equals(field)
+        && Arrays.equals(((KnnVectorQuery) obj).target, target);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(field, k, Arrays.hashCode(target));
+  }
+
+  /** Caches the results of a KnnVector search: a list of docs and their scores */
+  class DocAndScoreQuery extends Query {
+
+    private final int[] docs;
+    private final float[] scores;
+    private final int[] segmentStarts;
+
+    /**
+     * Constructor
+     *
+     * @param docs the global docids of documents that match, in ascending order
+     * @param scores the scores of the matching documents
+     * @param segmentStarts the indexes in docs and scores corresponding to the first matching
+     *     document in each segment. If a segment has no matching documents, it should be assigned
+     *     the index of the next segment that does. There should be a final entry that is always
+     *     docs.length.
+     */
+    DocAndScoreQuery(int[] docs, float[] scores, int[] segmentStarts) {
+      this.docs = docs;
+      this.scores = scores;
+      this.segmentStarts = segmentStarts;
+    }
+
+    @Override
+    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
+        throws IOException {
+      return new Weight(this) {
+        @Override
+        public Explanation explain(LeafReaderContext context, int doc) {
+          int found = Arrays.binarySearch(docs, doc);
+          if (found < 0) {
+            return Explanation.noMatch("not in top " + k);
+          }
+          return Explanation.match(scores[found], "within top " + k);
+        }
+
+        @Override
+        public Scorer scorer(LeafReaderContext context) {
+
+          return new Scorer(this) {
+            final int lower = segmentStarts[context.ord];
+            final int upper = segmentStarts[context.ord + 1];
+            int upTo = -1;
+
+            @Override
+            public DocIdSetIterator iterator() {
+              return new DocIdSetIterator() {
+                @Override
+                public int docID() {
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int nextDoc() {
+                  if (upTo == -1) {
+                    upTo = lower;
+                  } else {
+                    ++upTo;
+                  }
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int advance(int target) throws IOException {
+                  return slowAdvance(target);
+                }
+
+                @Override
+                public long cost() {
+                  return upper - lower;
+                }
+              };
+            }
+
+            @Override
+            public float getMaxScore(int docid) {
+              docid += context.docBase;
+              float maxScore = 0;
+              for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docid; idx++) {
+                maxScore = Math.max(maxScore, scores[idx]);
+              }
+              return maxScore;
+            }
+
+            @Override
+            public float score() {
+              if (upTo >= lower && upTo < upper) {
+                return scores[upTo];
+              }
+              return 0;
+            }
+
+            /**
+             * Move the implementation of docID() into a differently named method so we can call it
+             * from DocIdSetIterator.docID() even though this class is anonymous.
+             *
+             * @return the current docid
+             */
+            private int docIdNoShadow() {
+              if (upTo == -1) {
+                return -1;
+              }
+              if (upTo >= upper) {
+                return NO_MORE_DOCS;
+              }
+              return docs[upTo] - context.docBase;
+            }
+
+            @Override
+            public int docID() {
+              return docIdNoShadow();
+            }
+          };
+        }
+
+        @Override
+        public boolean isCacheable(LeafReaderContext ctx) {

Review comment:
       heh somehow I missed Adrien's last comment ... I retract my previous statement :)






[GitHub] [lucene] msokolov commented on a change in pull request #235: LUCENE-9614: add KnnVectorQuery implementation

Posted by GitBox <gi...@apache.org>.
msokolov commented on a change in pull request #235:
URL: https://github.com/apache/lucene/pull/235#discussion_r686734774



##########
File path: lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java
##########
@@ -43,9 +43,9 @@ public float compare(float[] v1, float[] v2) {
   };
 
   /**
-   * If true, the scores associated with vector comparisons are in reverse order; that is, lower
-   * scores represent more similar vectors. Otherwise, if false, higher scores represent more
-   * similar vectors.
+   * If true, the scores associated with vector comparisons are nonnegative and in reverse order;
+   * that is, lower scores represent more similar vectors. Otherwise, if false, higher scores
+   * represent more similar vectors, and scores may be negative or positive.

Review comment:
       The current patch does not change the `VectorSimilarity.compare` function implementations. It does the `1/(1+x)` normalization in `VectorsReader.search`, so I think the comment is still valid. I suppose we could require similarity functions to always return values in [0,1], with higher values meaning more similar, and eliminate the whole notion of reversed similarities. That would be a nice simplification to the API, and I think we've shown that it is achievable. But if we do decide that's best, I'd like to do it separately, since it will touch a bunch of places that aren't changed yet in this PR.
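
       To illustrate why the reciprocal form is attractive, here is a small stand-alone sketch comparing it with the earlier `exp(-distance)` conversion. `DistanceToScore` and its method names are made up for this example and are not part of the patch:

```java
/** Hypothetical helper comparing two ways of turning a non-negative distance into a score. */
final class DistanceToScore {

  /** exp(-d): sound in theory, but the float result underflows to 0 for large distances. */
  static float expScore(float distance) {
    return (float) Math.exp(-distance);
  }

  /** 1 / (1 + d): stays in (0, 1] and keeps large distances distinguishable. */
  static float reciprocalScore(float distance) {
    return 1f / (1f + distance);
  }
}
```

       With distances of 200 and 300, `expScore` yields the same float value (zero) for both, so their ordering is lost, while `reciprocalScore` still ranks the closer vector higher.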






[GitHub] [lucene] jtibshirani commented on a change in pull request #235: LUCENE-9614: add KnnVectorQuery implementation

Posted by GitBox <gi...@apache.org>.
jtibshirani commented on a change in pull request #235:
URL: https://github.com/apache/lucene/pull/235#discussion_r685918739



##########
File path: lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
##########
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Objects;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+
+/** Uses {@link KnnVectorsReader#search} to perform nearest-neighbor search. */
+public class KnnVectorQuery extends Query {
+
+  private static final TopDocs NO_RESULTS =
+      new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+
+  private final String field;
+  private final float[] target;
+  private final int k;
+
+  /**
+   * Find the <code>k</code> nearest documents to the <code>target</code> vector according to the
+   * vectors in the given field.
+   *
+   * @param field a field that has been indexed as a {@link KnnVectorField}.
+   * @param target the target of the search
+   * @param k the number of documents to find
+   * @throws IllegalArgumentException if <code>k</code> is less than 1
+   */
+  public KnnVectorQuery(String field, float[] target, int k) {
+    this.field = field;
+    this.target = target;
+    this.k = k;
+    if (k < 1) {
+      throw new IllegalArgumentException("k must be at least 1, got: " + k);
+    }
+  }
+
+  @Override
+  public Query rewrite(IndexReader reader) throws IOException {
+    int boundedK = Math.min(k, reader.numDocs());
+    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
+    for (LeafReaderContext ctx : reader.leaves()) {
+      // Calculate kPerLeaf as an overestimate of the expected number of the closest k documents in
+      // this leaf
+      int expectedKPerLeaf = Math.max(1, boundedK * ctx.reader().numDocs() / reader.numDocs());
+      // Increase to include 3 std. deviations of a Binomial distribution.
+      int kPerLeaf = (int) (expectedKPerLeaf + 3 * Math.sqrt(expectedKPerLeaf));
+      perLeafResults[ctx.ord] = searchLeaf(ctx, kPerLeaf);
+    }
+    // Merge sort the results
+    TopDocs topK = TopDocs.merge(boundedK, perLeafResults);
+    // re-query any outlier segments (normally there should be none).
+    topK = checkForOutlierSegments(reader, topK, perLeafResults);
+    if (topK.scoreDocs.length == 0) {
+      return new MatchNoDocsQuery();
+    }
+    return createRewrittenQuery(reader, topK);
+  }
+
+  private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
+    TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf);
+    if (results == null) {
+      return NO_RESULTS;
+    }
+    if (ctx.docBase > 0) {
+      for (ScoreDoc scoreDoc : results.scoreDocs) {
+        scoreDoc.doc += ctx.docBase;
+      }
+    }
+    return results;
+  }
+
+  private TopDocs checkForOutlierSegments(IndexReader reader, TopDocs topK, TopDocs[] perLeaf)
+      throws IOException {
+    int k = topK.scoreDocs.length;
+    if (k == 0) {
+      return topK;
+    }
+    float minScore = topK.scoreDocs[topK.scoreDocs.length - 1].score;
+    boolean rescored = false;
+    for (int i = 0; i < perLeaf.length; i++) {
+      if (perLeaf[i].scoreDocs[perLeaf[i].scoreDocs.length - 1].score >= minScore) {
+        // This segment's worst score was competitive; search it again, gathering full K this time
+        perLeaf[i] = searchLeaf(reader.leaves().get(i), topK.scoreDocs.length);
+        rescored = true;
+      }
+    }
+    if (rescored) {
+      return TopDocs.merge(k, perLeaf);
+    } else {
+      return topK;
+    }
+  }
+
+  private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
+    int len = topK.scoreDocs.length;
+    float minScore = topK.scoreDocs[len - 1].score;
+    Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
+    int[] docs = new int[len];
+    float[] scores = new float[len];
+    for (int i = 0; i < len; i++) {
+      docs[i] = topK.scoreDocs[i].doc;
+      scores[i] = topK.scoreDocs[i].score - minScore; // flip negative scores
+    }
+    int[] segmentStarts = findSegmentStarts(reader, docs);
+    return new DocAndScoreQuery(docs, scores, segmentStarts);
+  }
+
+  private int[] findSegmentStarts(IndexReader reader, int[] docs) {
+    int[] starts = new int[reader.leaves().size() + 1];
+    starts[starts.length - 1] = docs.length;
+    if (starts.length == 2) {
+      return starts;
+    }
+    int resultIndex = 0;
+    for (int i = 1; i < starts.length - 1; i++) {
+      int upper = reader.leaves().get(i).docBase;
+      resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
+      if (resultIndex < 0) {
+        resultIndex = -1 - resultIndex;
+      }
+      starts[i] = resultIndex;
+    }
+    return starts;
+  }
+
+  @Override
+  public String toString(String field) {
+    return "<vector:" + this.field + "[" + target[0] + ",...][" + k + "]>";
+  }
+
+  @Override
+  public void visit(QueryVisitor visitor) {
+    if (visitor.acceptField(field)) {
+      visitor.visitLeaf(this);
+    }
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    return obj instanceof KnnVectorQuery
+        && ((KnnVectorQuery) obj).k == k
+        && ((KnnVectorQuery) obj).field.equals(field)
+        && Arrays.equals(((KnnVectorQuery) obj).target, target);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(field, k, Arrays.hashCode(target));
+  }
+
+  /** Caches the results of a KnnVector search: a list of docs and their scores */
+  class DocAndScoreQuery extends Query {
+
+    private final int[] docs;
+    private final float[] scores;
+    private final int[] segmentStarts;
+
+    /**
+     * Constructor
+     *
+     * @param docs the global docids of documents that match, in ascending order
+     * @param scores the scores of the matching documents
+     * @param segmentStarts the indexes in docs and scores corresponding to the first matching
+     *     document in each segment. If a segment has no matching documents, it should be assigned
+     *     the index of the next segment that does. There should be a final entry that is always
+     *     docs.length.
+     */
+    DocAndScoreQuery(int[] docs, float[] scores, int[] segmentStarts) {
+      this.docs = docs;
+      this.scores = scores;
+      this.segmentStarts = segmentStarts;
+    }
+
+    @Override
+    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
+        throws IOException {
+      return new Weight(this) {
+        @Override
+        public Explanation explain(LeafReaderContext context, int doc) {
+          int found = Arrays.binarySearch(docs, doc);
+          if (found < 0) {
+            return Explanation.noMatch("not in top " + k);
+          }
+          return Explanation.match(scores[found], "within top " + k);
+        }
+
+        @Override
+        public Scorer scorer(LeafReaderContext context) {
+
+          return new Scorer(this) {
+            final int lower = segmentStarts[context.ord];
+            final int upper = segmentStarts[context.ord + 1];
+            int upTo = -1;
+
+            @Override
+            public DocIdSetIterator iterator() {
+              return new DocIdSetIterator() {
+                @Override
+                public int docID() {
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int nextDoc() {
+                  if (upTo == -1) {
+                    upTo = lower;
+                  } else {
+                    ++upTo;
+                  }
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int advance(int target) throws IOException {
+                  return slowAdvance(target);
+                }
+
+                @Override
+                public long cost() {
+                  return upper - lower;
+                }
+              };
+            }
+
+            @Override
+            public float getMaxScore(int docid) {
+              docid += context.docBase;
+              float maxScore = 0;
+              for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docid; idx++) {
+                maxScore = Math.max(maxScore, scores[idx]);
+              }
+              return maxScore;
+            }
+
+            @Override
+            public float score() {
+              if (upTo >= lower && upTo < upper) {
+                return scores[upTo];
+              }
+              return 0;
+            }
+
+            /**
+             * Move the implementation of docID() into a differently named method so we can call it
+             * from DocIdSetIterator.docID() even though this class is anonymous.
+             *
+             * @return the current docid
+             */
+            private int docIdNoShadow() {
+              if (upTo == -1) {
+                return -1;
+              }
+              if (upTo >= upper) {
+                return NO_MORE_DOCS;
+              }
+              return docs[upTo] - context.docBase;
+            }
+
+            @Override
+            public int docID() {
+              return docIdNoShadow();
+            }
+          };
+        }
+
+        @Override
+        public boolean isCacheable(LeafReaderContext ctx) {

Review comment:
       Is it helpful to cache this, given the expensive part is the rewrite?
   




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@lucene.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscribe@lucene.apache.org
For additional commands, e-mail: issues-help@lucene.apache.org


[GitHub] [lucene] jpountz commented on a change in pull request #235: LUCENE-9614: add KnnVectorQuery implementation

Posted by GitBox <gi...@apache.org>.
jpountz commented on a change in pull request #235:
URL: https://github.com/apache/lucene/pull/235#discussion_r686008850



##########
File path: lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
##########
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Objects;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+
+/** Uses {@link KnnVectorsReader#search} to perform nearest Neighbour search. */
+public class KnnVectorQuery extends Query {
+
+  private static final TopDocs NO_RESULTS =
+      new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+
+  private final String field;
+  private final float[] target;
+  private final int k;
+
+  /**
+   * Find the <code>k</code> nearest documents to the <code>target</code> vector according to the
+   * vectors in the given field.
+   *
+   * @param field a field that has been indexed as a {@link KnnVectorField}.
+   * @param target the target of the search
+   * @param k the number of documents to find
+   * @throws IllegalArgumentException if <code>k</code> is less than 1
+   */
+  public KnnVectorQuery(String field, float[] target, int k) {
+    this.field = field;
+    this.target = target;
+    this.k = k;
+    if (k < 1) {
+      throw new IllegalArgumentException("k must be at least 1, got: " + k);
+    }
+  }
+
+  @Override
+  public Query rewrite(IndexReader reader) throws IOException {
+    int boundedK = Math.min(k, reader.numDocs());
+    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
+    for (LeafReaderContext ctx : reader.leaves()) {
+      // Calculate kPerLeaf as an overestimate of the expected number of the closest k documents in
+      // this leaf
+      int expectedKPerLeaf = Math.max(1, boundedK * ctx.reader().numDocs() / reader.numDocs());
+      // Increase to include 3 std. deviations of a Binomial distribution.
+      int kPerLeaf = (int) (expectedKPerLeaf + 3 * Math.sqrt(expectedKPerLeaf));
+      perLeafResults[ctx.ord] = searchLeaf(ctx, kPerLeaf);
+    }
+    // Merge sort the results
+    TopDocs topK = TopDocs.merge(boundedK, perLeafResults);
+    // re-query any outlier segments (normally there should be none).
+    topK = checkForOutlierSegments(reader, topK, perLeafResults);
+    if (topK.scoreDocs.length == 0) {
+      return new MatchNoDocsQuery();
+    }
+    return createRewrittenQuery(reader, topK);
+  }
+
+  private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
+    TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf);
+    if (results == null) {
+      return NO_RESULTS;
+    }
+    if (ctx.docBase > 0) {
+      for (ScoreDoc scoreDoc : results.scoreDocs) {
+        scoreDoc.doc += ctx.docBase;
+      }
+    }
+    return results;
+  }
+
+  private TopDocs checkForOutlierSegments(IndexReader reader, TopDocs topK, TopDocs[] perLeaf)
+      throws IOException {
+    int k = topK.scoreDocs.length;
+    if (k == 0) {
+      return topK;
+    }
+    float minScore = topK.scoreDocs[topK.scoreDocs.length - 1].score;
+    boolean rescored = false;
+    for (int i = 0; i < perLeaf.length; i++) {
+      if (perLeaf[i].scoreDocs[perLeaf[i].scoreDocs.length - 1].score >= minScore) {
+        // This segment's worst score was competitive; search it again, gathering full K this time
+        perLeaf[i] = searchLeaf(reader.leaves().get(i), topK.scoreDocs.length);
+        rescored = true;
+      }
+    }
+    if (rescored) {
+      return TopDocs.merge(k, perLeaf);
+    } else {
+      return topK;
+    }
+  }
+
+  private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
+    int len = topK.scoreDocs.length;
+    float minScore = topK.scoreDocs[len - 1].score;
+    Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
+    int[] docs = new int[len];
+    float[] scores = new float[len];
+    for (int i = 0; i < len; i++) {
+      docs[i] = topK.scoreDocs[i].doc;
+      scores[i] = topK.scoreDocs[i].score - minScore; // flip negative scores
+    }
+    int[] segmentStarts = findSegmentStarts(reader, docs);
+    return new DocAndScoreQuery(docs, scores, segmentStarts);
+  }
+
+  private int[] findSegmentStarts(IndexReader reader, int[] docs) {
+    int[] starts = new int[reader.leaves().size() + 1];
+    starts[starts.length - 1] = docs.length;
+    if (starts.length == 2) {
+      return starts;
+    }
+    int resultIndex = 0;
+    for (int i = 1; i < starts.length - 1; i++) {
+      int upper = reader.leaves().get(i).docBase;
+      resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
+      if (resultIndex < 0) {
+        resultIndex = -1 - resultIndex;
+      }
+      starts[i] = resultIndex;
+    }
+    return starts;
+  }
+
+  @Override
+  public String toString(String field) {
+    return "<vector:" + this.field + "[" + target[0] + ",...][" + k + "]>";
+  }
+
+  @Override
+  public void visit(QueryVisitor visitor) {
+    if (visitor.acceptField(field)) {
+      visitor.visitLeaf(this);
+    }
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    return obj instanceof KnnVectorQuery
+        && ((KnnVectorQuery) obj).k == k
+        && ((KnnVectorQuery) obj).field.equals(field)
+        && Arrays.equals(((KnnVectorQuery) obj).target, target);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(field, k, Arrays.hashCode(target));
+  }
+
+  /** Caches the results of a KnnVector search: a list of docs and their scores */
+  class DocAndScoreQuery extends Query {
+
+    private final int[] docs;
+    private final float[] scores;
+    private final int[] segmentStarts;
+
+    /**
+     * Constructor
+     *
+     * @param docs the global docids of documents that match, in ascending order
+     * @param scores the scores of the matching documents
+     * @param segmentStarts the indexes in docs and scores corresponding to the first matching
+     *     document in each segment. If a segment has no matching documents, it should be assigned
+     *     the index of the next segment that does. There should be a final entry that is always
+     *     docs.length.
+     */
+    DocAndScoreQuery(int[] docs, float[] scores, int[] segmentStarts) {
+      this.docs = docs;
+      this.scores = scores;
+      this.segmentStarts = segmentStarts;
+    }
+
+    @Override
+    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
+        throws IOException {
+      return new Weight(this) {
+        @Override
+        public Explanation explain(LeafReaderContext context, int doc) {
+          int found = Arrays.binarySearch(docs, doc);
+          if (found < 0) {
+            return Explanation.noMatch("not in top " + k);
+          }
+          return Explanation.match(scores[found], "within top " + k);
+        }
+
+        @Override
+        public Scorer scorer(LeafReaderContext context) {
+
+          return new Scorer(this) {
+            final int lower = segmentStarts[context.ord];
+            final int upper = segmentStarts[context.ord + 1];
+            int upTo = -1;
+
+            @Override
+            public DocIdSetIterator iterator() {
+              return new DocIdSetIterator() {
+                @Override
+                public int docID() {
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int nextDoc() {
+                  if (upTo == -1) {
+                    upTo = lower;
+                  } else {
+                    ++upTo;
+                  }
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int advance(int target) throws IOException {
+                  return slowAdvance(target);
+                }
+
+                @Override
+                public long cost() {
+                  return upper - lower;
+                }
+              };
+            }
+
+            @Override
+            public float getMaxScore(int docid) {
+              docid += context.docBase;
+              float maxScore = 0;
+              for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docid; idx++) {
+                maxScore = Math.max(maxScore, scores[idx]);
+              }
+              return maxScore;
+            }
+
+            @Override
+            public float score() {
+              if (upTo >= lower && upTo < upper) {
+                return scores[upTo];
+              }
+              return 0;
+            }
+
+            /**
+             * move the implementation of docID() into a differently-named method so we can call it
+             * from DocIdSetIterator.docID() even though this class is anonymous
+             *
+             * @return the current docid
+             */
+            private int docIdNoShadow() {
+              if (upTo == -1) {
+                return -1;
+              }
+              if (upTo >= upper) {
+                return NO_MORE_DOCS;
+              }
+              return docs[upTo] - context.docBase;
+            }
+
+            @Override
+            public int docID() {
+              return docIdNoShadow();
+            }
+          };
+        }
+
+        @Override
+        public boolean isCacheable(LeafReaderContext ctx) {

Review comment:
       I think it is. If we returned false here, then any BooleanQuery that has a KnnVectorQuery as a clause would be considered not cacheable, even though it might have other clauses that are actually worth caching.






[GitHub] [lucene] msokolov commented on a change in pull request #235: LUCENE-9614: add KnnVectorQuery implementation

Posted by GitBox <gi...@apache.org>.
msokolov commented on a change in pull request #235:
URL: https://github.com/apache/lucene/pull/235#discussion_r686235524



##########
File path: lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
##########
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Objects;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+
+/** Uses {@link KnnVectorsReader#search} to perform nearest Neighbour search. */
+public class KnnVectorQuery extends Query {
+
+  private static final TopDocs NO_RESULTS =
+      new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+
+  private final String field;
+  private final float[] target;
+  private final int k;
+
+  /**
+   * Find the <code>k</code> nearest documents to the <code>target</code> vector according to the
+   * vectors in the given field.
+   *
+   * @param field a field that has been indexed as a {@link KnnVectorField}.
+   * @param target the target of the search
+   * @param k the number of documents to find
+   * @throws IllegalArgumentException if <code>k</code> is less than 1
+   */
+  public KnnVectorQuery(String field, float[] target, int k) {
+    this.field = field;
+    this.target = target;
+    this.k = k;
+    if (k < 1) {
+      throw new IllegalArgumentException("k must be at least 1, got: " + k);
+    }
+  }
+
+  @Override
+  public Query rewrite(IndexReader reader) throws IOException {
+    int boundedK = Math.min(k, reader.numDocs());
+    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
+    for (LeafReaderContext ctx : reader.leaves()) {
+      // Calculate kPerLeaf as an overestimate of the expected number of the closest k documents in
+      // this leaf
+      int expectedKPerLeaf = Math.max(1, boundedK * ctx.reader().numDocs() / reader.numDocs());
+      // Increase to include 3 std. deviations of a Binomial distribution.
+      int kPerLeaf = (int) (expectedKPerLeaf + 3 * Math.sqrt(expectedKPerLeaf));
+      perLeafResults[ctx.ord] = searchLeaf(ctx, kPerLeaf);
+    }
+    // Merge sort the results
+    TopDocs topK = TopDocs.merge(boundedK, perLeafResults);
+    // re-query any outlier segments (normally there should be none).
+    topK = checkForOutlierSegments(reader, topK, perLeafResults);
+    if (topK.scoreDocs.length == 0) {
+      return new MatchNoDocsQuery();
+    }
+    return createRewrittenQuery(reader, topK);
+  }
+
+  private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
+    TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf);
+    if (results == null) {
+      return NO_RESULTS;
+    }
+    if (ctx.docBase > 0) {
+      for (ScoreDoc scoreDoc : results.scoreDocs) {
+        scoreDoc.doc += ctx.docBase;
+      }
+    }
+    return results;
+  }
+
+  private TopDocs checkForOutlierSegments(IndexReader reader, TopDocs topK, TopDocs[] perLeaf)
+      throws IOException {
+    int k = topK.scoreDocs.length;
+    if (k == 0) {
+      return topK;
+    }
+    float minScore = topK.scoreDocs[topK.scoreDocs.length - 1].score;
+    boolean rescored = false;
+    for (int i = 0; i < perLeaf.length; i++) {
+      if (perLeaf[i].scoreDocs[perLeaf[i].scoreDocs.length - 1].score >= minScore) {
+        // This segment's worst score was competitive; search it again, gathering full K this time
+        perLeaf[i] = searchLeaf(reader.leaves().get(i), topK.scoreDocs.length);
+        rescored = true;
+      }
+    }
+    if (rescored) {
+      return TopDocs.merge(k, perLeaf);
+    } else {
+      return topK;
+    }
+  }
+
+  private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
+    int len = topK.scoreDocs.length;
+    float minScore = topK.scoreDocs[len - 1].score;
+    Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
+    int[] docs = new int[len];
+    float[] scores = new float[len];
+    for (int i = 0; i < len; i++) {
+      docs[i] = topK.scoreDocs[i].doc;
+      scores[i] = topK.scoreDocs[i].score - minScore; // flip negative scores
+    }
+    int[] segmentStarts = findSegmentStarts(reader, docs);
+    return new DocAndScoreQuery(docs, scores, segmentStarts);
+  }
+
+  private int[] findSegmentStarts(IndexReader reader, int[] docs) {
+    int[] starts = new int[reader.leaves().size() + 1];
+    starts[starts.length - 1] = docs.length;
+    if (starts.length == 2) {
+      return starts;
+    }
+    int resultIndex = 0;
+    for (int i = 1; i < starts.length - 1; i++) {
+      int upper = reader.leaves().get(i).docBase;
+      resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
+      if (resultIndex < 0) {
+        resultIndex = -1 - resultIndex;
+      }
+      starts[i] = resultIndex;
+    }
+    return starts;
+  }
+
+  @Override
+  public String toString(String field) {
+    return "<vector:" + this.field + "[" + target[0] + ",...][" + k + "]>";
+  }
+
+  @Override
+  public void visit(QueryVisitor visitor) {
+    if (visitor.acceptField(field)) {
+      visitor.visitLeaf(this);
+    }
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    return obj instanceof KnnVectorQuery
+        && ((KnnVectorQuery) obj).k == k
+        && ((KnnVectorQuery) obj).field.equals(field)
+        && Arrays.equals(((KnnVectorQuery) obj).target, target);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(field, k, Arrays.hashCode(target));
+  }
+
+  /** Caches the results of a KnnVector search: a list of docs and their scores */
+  class DocAndScoreQuery extends Query {
+
+    private final int[] docs;
+    private final float[] scores;
+    private final int[] segmentStarts;
+
+    /**
+     * Constructor
+     *
+     * @param docs the global docids of documents that match, in ascending order
+     * @param scores the scores of the matching documents
+     * @param segmentStarts the indexes in docs and scores corresponding to the first matching
+     *     document in each segment. If a segment has no matching documents, it should be assigned
+     *     the index of the next segment that does. There should be a final entry that is always
+     *     docs.length.
+     */
+    DocAndScoreQuery(int[] docs, float[] scores, int[] segmentStarts) {
+      this.docs = docs;
+      this.scores = scores;
+      this.segmentStarts = segmentStarts;
+    }
+
+    @Override
+    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
+        throws IOException {
+      return new Weight(this) {
+        @Override
+        public Explanation explain(LeafReaderContext context, int doc) {
+          int found = Arrays.binarySearch(docs, doc);
+          if (found < 0) {
+            return Explanation.noMatch("not in top " + k);
+          }
+          return Explanation.match(scores[found], "within top " + k);
+        }
+
+        @Override
+        public Scorer scorer(LeafReaderContext context) {
+
+          return new Scorer(this) {
+            final int lower = segmentStarts[context.ord];
+            final int upper = segmentStarts[context.ord + 1];
+            int upTo = -1;
+
+            @Override
+            public DocIdSetIterator iterator() {
+              return new DocIdSetIterator() {
+                @Override
+                public int docID() {
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int nextDoc() {
+                  if (upTo == -1) {
+                    upTo = lower;
+                  } else {
+                    ++upTo;
+                  }
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int advance(int target) throws IOException {
+                  return slowAdvance(target);
+                }
+
+                @Override
+                public long cost() {
+                  return upper - lower;
+                }
+              };
+            }
+
+            @Override
+            public float getMaxScore(int docid) {
+              docid += context.docBase;
+              float maxScore = 0;
+              for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docid; idx++) {
+                maxScore = Math.max(maxScore, scores[idx]);
+              }
+              return maxScore;
+            }
+
+            @Override
+            public float score() {
+              if (upTo >= lower && upTo < upper) {
+                return scores[upTo];
+              }
+              return 0;

Review comment:
       OK; I wasn't sure we could make such an assertion. I'll remove the checks and update the test I added for this edge case, which I guess doesn't reflect normal usage.






[GitHub] [lucene] jtibshirani commented on pull request #235: LUCENE-9614: add KnnVectorQuery implementation

Posted by GitBox <gi...@apache.org>.
jtibshirani commented on pull request #235:
URL: https://github.com/apache/lucene/pull/235#issuecomment-896029536


   > I noticed that scores for the default similarity (Euclidean) had very low precision as they got large... The way we were handling this was to apply an `exp(-distance)` to convert distances to scores.
   
   I wonder if we could just swap in `f(x) = 1 / (1 + x)`, which decays a lot more slowly than `exp(-x)`. This maintains the nice property of producing scores within [0, 1].
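
To make the comparison concrete, here is a minimal sketch (not code from the PR) of the two conversions applied to single-precision values. The class and method names are made up for illustration; the only assumption is that scores end up as `float`, as Lucene scores do.

```java
// Hypothetical helper, for illustration only; not part of the PR.
final class DistanceToScore {

  /** exp(-d): the conversion described in the PR. Underflows to 0 for large d. */
  static float expScore(float distance) {
    return (float) Math.exp(-distance);
  }

  /** 1 / (1 + d): the suggested alternative. Decays far more slowly. */
  static float reciprocalScore(float distance) {
    return 1f / (1f + distance);
  }

  public static void main(String[] args) {
    float[] distances = {0f, 1f, 50f, 200f, 300f};
    for (float d : distances) {
      // Print both scores side by side so the underflow is visible.
      System.out.printf("d=%.0f exp=%e reciprocal=%e%n", d, expScore(d), reciprocalScore(d));
    }
  }
}
```

With `exp(-d)`, distances of 200 and 300 both map to exactly `0.0f`, so their ranking is lost, whereas `1 / (1 + d)` still orders them correctly.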
   
   > There's a clever implementation (hack?!) to deal with trying to minimize over-collection across multiple segments. Basically the idea is to optimistically collect the expected proportion of top K based on the segment size (plus a margin)...
   
   This is a nice idea! The binomial estimate assumes that the nearest vectors are randomly distributed throughout the index. But since segment membership is related to when a document was indexed, I wonder if it'll be common for most nearest neighbors to be found in one segment. For example, maybe we are indexing (and embedding) news articles as they're written, and our query is a news event. Would it make sense to start with a simple approach where we just collect `k` from each segment? We could then explore optimizations in a follow-up, with benchmarks.
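
As a rough illustration of the two options (a sketch, not the PR's code): `optimisticK` mirrors the per-segment estimate computed in `rewrite()`, and `fullK` is the simple approach of collecting `k` from every segment. Both names are hypothetical, and the `long` arithmetic is an added guard against `int` overflow that the quoted diff does not have.

```java
// Hypothetical helper, for illustration only; not part of the PR.
final class PerLeafK {

  /**
   * Expected share of the top k that falls in a segment, padded by three times
   * sqrt(expected), as in the quoted rewrite() method.
   */
  static int optimisticK(int k, int leafDocs, int totalDocs) {
    int expected = Math.max(1, (int) ((long) k * leafDocs / totalDocs));
    return (int) (expected + 3 * Math.sqrt(expected));
  }

  /** Simple alternative: ask each segment for the full k, bounded by its size. */
  static int fullK(int k, int leafDocs) {
    return Math.min(k, leafDocs);
  }

  public static void main(String[] args) {
    // A segment holding 10% of the index, k = 100.
    System.out.println(optimisticK(100, 1_000, 10_000));
    System.out.println(fullK(100, 1_000));
  }
}
```

Note that for small `k` the padded estimate can exceed `k` itself: with `k = 10` and a segment holding half the index, `optimisticK` asks for 11. Any saving would therefore come mainly from larger `k` or many small segments, which is something benchmarks would need to confirm.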




[GitHub] [lucene] msokolov commented on a change in pull request #235: LUCENE-9614: add KnnVectorQuery implementation

Posted by GitBox <gi...@apache.org>.
msokolov commented on a change in pull request #235:
URL: https://github.com/apache/lucene/pull/235#discussion_r686237845



##########
File path: lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
##########
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Objects;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+
+/** Uses {@link KnnVectorsReader#search} to perform nearest Neighbour search. */
+public class KnnVectorQuery extends Query {
+
+  private static final TopDocs NO_RESULTS =
+      new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+
+  private final String field;
+  private final float[] target;
+  private final int k;
+
+  /**
+   * Find the <code>k</code> nearest documents to the <code>target</code> vector according to the
+   * vectors in the given field.
+   *
+   * @param field a field that has been indexed as a {@link KnnVectorField}.
+   * @param target the target of the search
+   * @param k the number of documents to find
+   * @throws IllegalArgumentException if <code>k</code> is less than 1
+   */
+  public KnnVectorQuery(String field, float[] target, int k) {
+    this.field = field;
+    this.target = target;
+    this.k = k;
+    if (k < 1) {
+      throw new IllegalArgumentException("k must be at least 1, got: " + k);
+    }
+  }
+
+  @Override
+  public Query rewrite(IndexReader reader) throws IOException {
+    int boundedK = Math.min(k, reader.numDocs());
+    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
+    for (LeafReaderContext ctx : reader.leaves()) {
+      // Calculate kPerLeaf as an overestimate of the expected number of the closest k documents in
+      // this leaf
+      int expectedKPerLeaf = Math.max(1, boundedK * ctx.reader().numDocs() / reader.numDocs());
+      // Increase to include 3 std. deviations of a Binomial distribution.
+      int kPerLeaf = (int) (expectedKPerLeaf + 3 * Math.sqrt(expectedKPerLeaf));
+      perLeafResults[ctx.ord] = searchLeaf(ctx, kPerLeaf);
+    }
+    // Merge sort the results
+    TopDocs topK = TopDocs.merge(boundedK, perLeafResults);
+    // re-query any outlier segments (normally there should be none).
+    topK = checkForOutlierSegments(reader, topK, perLeafResults);
+    if (topK.scoreDocs.length == 0) {
+      return new MatchNoDocsQuery();
+    }
+    return createRewrittenQuery(reader, topK);
+  }
+
+  private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
+    TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf);
+    if (results == null) {
+      return NO_RESULTS;
+    }
+    if (ctx.docBase > 0) {
+      for (ScoreDoc scoreDoc : results.scoreDocs) {
+        scoreDoc.doc += ctx.docBase;
+      }
+    }
+    return results;
+  }
+
+  private TopDocs checkForOutlierSegments(IndexReader reader, TopDocs topK, TopDocs[] perLeaf)
+      throws IOException {
+    int k = topK.scoreDocs.length;
+    if (k == 0) {
+      return topK;
+    }
+    float minScore = topK.scoreDocs[topK.scoreDocs.length - 1].score;
+    boolean rescored = false;
+    for (int i = 0; i < perLeaf.length; i++) {
+      if (perLeaf[i].scoreDocs[perLeaf[i].scoreDocs.length - 1].score >= minScore) {
+        // This segment's worst score was competitive; search it again, gathering full K this time
+        perLeaf[i] = searchLeaf(reader.leaves().get(i), topK.scoreDocs.length);
+        rescored = true;
+      }
+    }
+    if (rescored) {
+      return TopDocs.merge(k, perLeaf);
+    } else {
+      return topK;
+    }
+  }
+
+  private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
+    int len = topK.scoreDocs.length;
+    float minScore = topK.scoreDocs[len - 1].score;
+    Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
+    int[] docs = new int[len];
+    float[] scores = new float[len];
+    for (int i = 0; i < len; i++) {
+      docs[i] = topK.scoreDocs[i].doc;
+      scores[i] = topK.scoreDocs[i].score - minScore; // flip negative scores
+    }
+    int[] segmentStarts = findSegmentStarts(reader, docs);
+    return new DocAndScoreQuery(docs, scores, segmentStarts);
+  }
+
+  private int[] findSegmentStarts(IndexReader reader, int[] docs) {
+    int[] starts = new int[reader.leaves().size() + 1];
+    starts[starts.length - 1] = docs.length;
+    if (starts.length == 2) {
+      return starts;
+    }
+    int resultIndex = 0;
+    for (int i = 1; i < starts.length - 1; i++) {
+      int upper = reader.leaves().get(i).docBase;
+      resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
+      if (resultIndex < 0) {
+        resultIndex = -1 - resultIndex;
+      }
+      starts[i] = resultIndex;
+    }
+    return starts;
+  }
+
+  @Override
+  public String toString(String field) {
+    return "<vector:" + this.field + "[" + target[0] + ",...][" + k + "]>";
+  }
+
+  @Override
+  public void visit(QueryVisitor visitor) {
+    if (visitor.acceptField(field)) {
+      visitor.visitLeaf(this);
+    }
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    return obj instanceof KnnVectorQuery
+        && ((KnnVectorQuery) obj).k == k
+        && ((KnnVectorQuery) obj).field.equals(field)
+        && Arrays.equals(((KnnVectorQuery) obj).target, target);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(field, k, Arrays.hashCode(target));
+  }
+
+  /** Caches the results of a KnnVector search: a list of docs and their scores */
+  class DocAndScoreQuery extends Query {
+
+    private final int[] docs;
+    private final float[] scores;
+    private final int[] segmentStarts;
+
+    /**
+     * Constructor
+     *
+     * @param docs the global docids of documents that match, in ascending order
+     * @param scores the scores of the matching documents
+     * @param segmentStarts the indexes in docs and scores corresponding to the first matching
+     *     document in each segment. If a segment has no matching documents, it should be assigned
+     *     the index of the next segment that does. There should be a final entry that is always
+     *     docs.length.
+     */
+    DocAndScoreQuery(int[] docs, float[] scores, int[] segmentStarts) {
+      this.docs = docs;
+      this.scores = scores;
+      this.segmentStarts = segmentStarts;
+    }
+
+    @Override
+    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
+        throws IOException {
+      return new Weight(this) {
+        @Override
+        public Explanation explain(LeafReaderContext context, int doc) {
+          int found = Arrays.binarySearch(docs, doc);
+          if (found < 0) {
+            return Explanation.noMatch("not in top " + k);
+          }
+          return Explanation.match(scores[found], "within top " + k);
+        }
+
+        @Override
+        public Scorer scorer(LeafReaderContext context) {
+
+          return new Scorer(this) {
+            final int lower = segmentStarts[context.ord];
+            final int upper = segmentStarts[context.ord + 1];
+            int upTo = -1;
+
+            @Override
+            public DocIdSetIterator iterator() {
+              return new DocIdSetIterator() {
+                @Override
+                public int docID() {
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int nextDoc() {
+                  if (upTo == -1) {
+                    upTo = lower;
+                  } else {
+                    ++upTo;
+                  }
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int advance(int target) throws IOException {
+                  return slowAdvance(target);
+                }
+
+                @Override
+                public long cost() {
+                  return upper - lower;
+                }
+              };
+            }
+
+            @Override
+            public float getMaxScore(int docid) {
+              docid += context.docBase;
+              float maxScore = 0;
+              for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docid; idx++) {
+                maxScore = Math.max(maxScore, scores[idx]);
+              }
+              return maxScore;
+            }
+
+            @Override
+            public float score() {
+              if (upTo >= lower && upTo < upper) {
+                return scores[upTo];
+              }
+              return 0;
+            }
+
+            /**
+             * move the implementation of docID() into a differently-named method so we can call it
+             * from DocIdSetIterator.docID() even though this class is anonymous
+             *
+             * @return the current docid
+             */
+            private int docIdNoShadow() {
+              if (upTo == -1) {
+                return -1;
+              }
+              if (upTo >= upper) {
+                return NO_MORE_DOCS;
+              }
+              return docs[upTo] - context.docBase;
+            }
+
+            @Override
+            public int docID() {
+              return docIdNoShadow();
+            }
+          };
+        }
+
+        @Override
+        public boolean isCacheable(LeafReaderContext ctx) {

Review comment:
       I'm not sure; I don't have a lot of experience using the cache, and I guess my idea was that it is *possible* to cache (cacheable)? But ... I guess we only cache the rewritten query? In which case there's really no point since this is already effectively a cache ... yeah I'll change to return false.
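
For readers following the diff above, here is a minimal standalone sketch of the segment bookkeeping that `findSegmentStarts` performs. It is not the PR's code: it assumes the per-segment `docBase` values are passed in as a plain `int[]` (a stand-in for `reader.leaves()`), and the class name `SegmentStartsSketch` is hypothetical.

```java
import java.util.Arrays;

final class SegmentStartsSketch {

  /**
   * docBases: first global doc id of each segment, ascending (assumed input).
   * docs: global doc ids of the matching documents, ascending.
   * Returns, for each segment, the index in docs of its first match, plus a
   * final entry equal to docs.length.
   */
  static int[] findSegmentStarts(int[] docBases, int[] docs) {
    int[] starts = new int[docBases.length + 1];
    starts[starts.length - 1] = docs.length;
    int resultIndex = 0;
    for (int i = 1; i < docBases.length; i++) {
      // Search only the part of docs that earlier segments did not claim.
      resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, docBases[i]);
      if (resultIndex < 0) {
        // Not found: convert the encoded insertion point to an index.
        resultIndex = -1 - resultIndex;
      }
      starts[i] = resultIndex;
    }
    return starts;
  }

  public static void main(String[] args) {
    int[] starts = findSegmentStarts(new int[] {0, 100, 200}, new int[] {5, 42, 250});
    System.out.println(Arrays.toString(starts));
  }
}
```

With doc bases {0, 100, 200} and matching docs {5, 42, 250}, this yields {0, 2, 2, 3}: the middle segment has no matches, so its start equals the next segment's start, and the final entry is docs.length.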




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@lucene.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscribe@lucene.apache.org
For additional commands, e-mail: issues-help@lucene.apache.org


[GitHub] [lucene] jtibshirani edited a comment on pull request #235: LUCENE-9614: add KnnVectorQuery implementation

Posted by GitBox <gi...@apache.org>.
jtibshirani edited a comment on pull request #235:
URL: https://github.com/apache/lucene/pull/235#issuecomment-896029536


   > I noticed that scores for the default similarity (Euclidean) had very low precision as they got large... The way we were handling this was to apply an `exp(-distance)` to convert distances to scores.
   
   I wonder if we could just swap in `f(x) = 1 / (1 + x)`, which decays a lot more slowly than `exp(-x)`. This maintains the nice property of producing scores within [0, 1].
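
To make the difference between the two mappings concrete, here is a small sketch comparing them in single precision. The class and method names are hypothetical and this is not Lucene code; it only illustrates how `exp(-x)` collapses to zero for large distances while `1 / (1 + x)` keeps them distinguishable.

```java
final class ScoreMappingSketch {

  /** Exponential mapping described in the PR: underflows for large distances. */
  static float expScore(float distance) {
    return (float) Math.exp(-distance);
  }

  /** Reciprocal mapping suggested in this comment: decays far more slowly. */
  static float reciprocalScore(float distance) {
    return 1f / (1f + distance);
  }

  public static void main(String[] args) {
    float[] distances = {0f, 1f, 50f, 150f, 1000f};
    for (float d : distances) {
      System.out.printf(
          "distance=%.0f exp=%e reciprocal=%e%n", d, expScore(d), reciprocalScore(d));
    }
  }
}
```

In float arithmetic, distances of 150 and 1000 both map to 0.0 under `exp(-x)` and so cannot be ranked, whereas the reciprocal form still orders them correctly and stays within (0, 1] for nonnegative distances.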
   
   > There's a clever implementation (hack?!) to deal with trying to minimize over-collection across multiple segments. Basically the idea is to optimistically collect the expected proportion of top K based on the segment size (plus a margin)...
   
   This is a nice idea! The binomial estimate is based on the idea that nearest vectors are randomly distributed through the index. But since segment membership is related to when a document was indexed, I wonder if it'll be common for most nearest neighbors to be found in one segment. For example, maybe we are indexing (and embedding) news articles as they're written, and our query is a news event. Would it make sense to start with a simple approach where we just collect 'k' from each segment? Then we would explore optimizations in a follow-up with benchmarks?
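
As an illustration of the simple approach, here is a sketch that takes up to `k` hits from every segment and merges them into a global top `k`. The `Hit` type and `mergeTopK` method are hypothetical stand-ins written in plain Java, not Lucene's API, and higher scores are assumed to be better.

```java
import java.util.ArrayList;
import java.util.List;

final class PerSegmentTopKSketch {

  /** A hit identified by a global doc id, with a score where higher is better. */
  static final class Hit {
    final int doc;
    final float score;

    Hit(int doc, float score) {
      this.doc = doc;
      this.score = score;
    }
  }

  /** Merges per-segment results (each holding at most k hits) into a global top k. */
  static List<Hit> mergeTopK(List<List<Hit>> perSegmentHits, int k) {
    List<Hit> all = new ArrayList<>();
    for (List<Hit> segmentHits : perSegmentHits) {
      all.addAll(segmentHits);
    }
    // Best score first; ties broken by ascending doc id.
    all.sort(
        (a, b) -> {
          int byScore = Float.compare(b.score, a.score);
          return byScore != 0 ? byScore : Integer.compare(a.doc, b.doc);
        });
    return new ArrayList<>(all.subList(0, Math.min(k, all.size())));
  }

  public static void main(String[] args) {
    List<List<Hit>> perSegment = new ArrayList<>();
    perSegment.add(List.of(new Hit(0, 0.90f), new Hit(1, 0.80f)));
    perSegment.add(List.of(new Hit(10, 0.95f), new Hit(11, 0.10f)));
    for (Hit hit : mergeTopK(perSegment, 2)) {
      System.out.println(hit.doc + " " + hit.score);
    }
  }
}
```

Because every segment contributes up to `k` candidates, the merged result stays correct even when all of the nearest neighbors happen to live in a single segment; the cost is that each segment's search does the full `k` worth of work.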



[GitHub] [lucene] msokolov commented on pull request #235: LUCENE-9614: add KnnVectorQuery implementation

Posted by GitBox <gi...@apache.org>.
msokolov commented on pull request #235:
URL: https://github.com/apache/lucene/pull/235#issuecomment-896098498


   Thanks for all the comments; I'll follow up with a new commit that addresses them soon. `1 / (1 + x)` makes a lot of sense; I was groping towards it :)
   
   Re: the random-distribution assumption for segments -- I believe this depends very much on the use case. Our experience in e-commerce is it is *usually* true. We've seen occasional outlying cases (more popular media products get re-indexed more often, and there can be correlation if *popularity* is an important query feature, which it is), but this is more the exception than the rule. OTOH a time-series index is likely to be heavily correlated, so a different strategy is appropriate (also, sequential operation can more easily re-use thresholds across segments, and if the segments can be sorted, that will help). Perhaps the vanilla approach (collect K per segment) is best as a safe first step, but I think some optimization here will be heavily impactful since the `K` directly influences the number of nodes explored in the graph, and thence the query cost. Maybe it will deserve some kind of parameterization - so yes, I agree, let's remove this for now, and follow up later.
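
For context, the pro-rated collection being set aside here could be sketched roughly as follows. This is an assumption-laden illustration, not the PR's implementation: it treats the number of global top-K hits landing in a segment as binomial with p = segmentDocs / totalDocs, and pads the mean by a configurable number of standard deviations. The names `perSegmentK` and `stdDevs` are hypothetical.

```java
final class ProratedKSketch {

  /**
   * Estimates how many hits to collect from one segment, assuming nearest
   * neighbors are spread randomly across the index (the assumption discussed
   * in this thread). The result is clamped to the range [1, k].
   */
  static int perSegmentK(int k, int segmentDocs, int totalDocs, double stdDevs) {
    if (k <= 0 || segmentDocs <= 0 || totalDocs <= 0) {
      return 0;
    }
    double p = Math.min(1.0, (double) segmentDocs / totalDocs);
    double mean = k * p; // expected share of the top k
    double stdDev = Math.sqrt(k * p * (1.0 - p)); // binomial standard deviation
    int estimate = (int) Math.ceil(mean + stdDevs * stdDev);
    return Math.max(1, Math.min(k, estimate));
  }

  public static void main(String[] args) {
    System.out.println(perSegmentK(100, 500, 1000, 3.0));
    System.out.println(perSegmentK(100, 1000, 1000, 3.0));
  }
}
```

With k = 100 and a segment holding half the index, a margin of three standard deviations asks that segment for 65 hits instead of 100. If the index is skewed, as in the time-series case, such an estimate will often be too small and the query has to be re-run, which is the risk raised above.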



[GitHub] [lucene] msokolov commented on pull request #235: LUCENE-9614: add KnnVectorQuery implementation

Posted by GitBox <gi...@apache.org>.
msokolov commented on pull request #235:
URL: https://github.com/apache/lucene/pull/235#issuecomment-898416859


   I ran a luceneutil test comparing this against the KnnQuery implementation we have there, which is implemented in createWeight rather than in rewrite, and saw no difference. It's a little bit bogus as a comparison perhaps, but it's the best we have right now, and at least it shows we didn't do anything more boneheaded than before.
   
   ```
   Task                 QPS baseline   StdDev  QPS candidate   StdDev  Pct diff               p-value
   LowTermVector              685.43   (6.9%)         664.26   (7.1%)  -3.1% ( -15% -   11%)    0.162
   AndHighMedVector           656.70   (6.1%)         646.58   (2.9%)  -1.5% (  -9% -    7%)    0.308
   HighTermVector             703.33  (11.8%)         704.22   (5.9%)   0.1% ( -15% -   20%)    0.966
   AndHighLowVector           667.55   (7.6%)         669.66   (5.2%)   0.3% ( -11% -   14%)    0.878
   PKLookup                   187.46   (1.2%)         188.41   (0.6%)   0.5% (  -1% -    2%)    0.100
   MedTermVector              636.03   (5.9%)         645.88   (5.0%)   1.5% (  -8% -   13%)    0.371
   AndHighHighVector          642.66   (5.8%)         669.80   (6.3%)   4.2% (  -7% -   17%)    0.027
   ```
   
   By the way, I also tried the pro-rating idea I had posted earlier, with mixed results: it consistently made HighTermVector better and MedTermVector worse (by quite a bit, around 15% less QPS), which really surprised me. But perhaps having a tiny priority queue (top K = 1, say) would make the graph exploration quite a bit less efficient? It's also possible this index is skewed and the query is having to re-run a bunch of times... Needs further investigation.
   
   Finally, I think this is ready to push. I'll push later today if there are no new issues raised.



[GitHub] [lucene] msokolov commented on a change in pull request #235: LUCENE-9614: add KnnVectorQuery implementation

Posted by GitBox <gi...@apache.org>.
msokolov commented on a change in pull request #235:
URL: https://github.com/apache/lucene/pull/235#discussion_r686891113



##########
File path: lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java
##########
@@ -43,9 +43,9 @@ public float compare(float[] v1, float[] v2) {
   };
 
   /**
-   * If true, the scores associated with vector comparisons are in reverse order; that is, lower
-   * scores represent more similar vectors. Otherwise, if false, higher scores represent more
-   * similar vectors.
+   * If true, the scores associated with vector comparisons are nonnegative and in reverse order;
+   * that is, lower scores represent more similar vectors. Otherwise, if false, higher scores
+   * represent more similar vectors, and scores may be negative or positive.

Review comment:
       Heh, in fact this edit was never needed, since we never changed the similarity scores (as I mentioned above), only the scores returned from search... so I'll indeed revert.





[GitHub] [lucene] jtibshirani commented on a change in pull request #235: LUCENE-9614: add KnnVectorQuery implementation

Posted by GitBox <gi...@apache.org>.
jtibshirani commented on a change in pull request #235:
URL: https://github.com/apache/lucene/pull/235#discussion_r686817275



##########
File path: lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java
##########
@@ -43,9 +43,9 @@ public float compare(float[] v1, float[] v2) {
   };
 
   /**
-   * If true, the scores associated with vector comparisons are in reverse order; that is, lower
-   * scores represent more similar vectors. Otherwise, if false, higher scores represent more
-   * similar vectors.
+   * If true, the scores associated with vector comparisons are nonnegative and in reverse order;
+   * that is, lower scores represent more similar vectors. Otherwise, if false, higher scores
+   * represent more similar vectors, and scores may be negative or positive.

Review comment:
       Oh I had misunderstood this docs update. Sounds good to save your idea for a follow-up.



