You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by ju...@apache.org on 2023/02/16 20:04:06 UTC

[lucene] branch main updated: Simplify max score for kNN vector queries (#12146)

This is an automated email from the ASF dual-hosted git repository.

julietibs pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/main by this push:
     new 8340b01c3cc Simplify max score for kNN vector queries (#12146)
8340b01c3cc is described below

commit 8340b01c3cc229f33584ce2178b07b8984daa6a9
Author: Julie Tibshirani <ju...@apache.org>
AuthorDate: Thu Feb 16 12:03:59 2023 -0800

    Simplify max score for kNN vector queries (#12146)
    
    The helper class DocAndScoreQuery implements advanceShallow to help skip
    non-competitive documents. This method doesn't actually keep track of where it
    has advanced, which means it can do extra work.
    
    Overall the complexity here didn't seem worth it, given the low cost of
    collecting matching kNN docs. This PR switches to a simpler approach, which uses
    a fixed upper bound on the max score.
---
 .../lucene/search/AbstractKnnVectorQuery.java      | 29 ++++-------
 .../lucene/search/BaseKnnVectorQueryTestCase.java  | 58 +++++-----------------
 .../lucene/search/TestKnnFloatVectorQuery.java     | 24 ++++-----
 3 files changed, 32 insertions(+), 79 deletions(-)

diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
index 6e348fcc5ee..f2e4b125f98 100644
--- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
@@ -189,6 +189,10 @@ abstract class AbstractKnnVectorQuery extends Query {
 
   private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
     int len = topK.scoreDocs.length;
+
+    assert len > 0;
+    float maxScore = topK.scoreDocs[0].score;
+
     Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
     int[] docs = new int[len];
     float[] scores = new float[len];
@@ -197,7 +201,7 @@ abstract class AbstractKnnVectorQuery extends Query {
       scores[i] = topK.scoreDocs[i].score;
     }
     int[] segmentStarts = findSegmentStarts(reader, docs);
-    return new DocAndScoreQuery(docs, scores, segmentStarts, reader.getContext().id());
+    return new DocAndScoreQuery(docs, scores, maxScore, segmentStarts, reader.getContext().id());
   }
 
   private int[] findSegmentStarts(IndexReader reader, int[] docs) {
@@ -265,6 +269,7 @@ abstract class AbstractKnnVectorQuery extends Query {
 
     private final int[] docs;
     private final float[] scores;
+    private final float maxScore;
     private final int[] segmentStarts;
     private final Object contextIdentity;
 
@@ -280,9 +285,11 @@ abstract class AbstractKnnVectorQuery extends Query {
      * @param contextIdentity an object identifying the reader context that was used to build this
      *     query
      */
-    DocAndScoreQuery(int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
+    DocAndScoreQuery(
+        int[] docs, float[] scores, float maxScore, int[] segmentStarts, Object contextIdentity) {
       this.docs = docs;
       this.scores = scores;
+      this.maxScore = maxScore;
       this.segmentStarts = segmentStarts;
       this.contextIdentity = contextIdentity;
     }
@@ -343,11 +350,6 @@ abstract class AbstractKnnVectorQuery extends Query {
 
             @Override
             public float getMaxScore(int docId) {
-              docId += context.docBase;
-              float maxScore = 0;
-              for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docId; idx++) {
-                maxScore = Math.max(maxScore, scores[idx]);
-              }
               return maxScore * boost;
             }
 
@@ -356,19 +358,6 @@ abstract class AbstractKnnVectorQuery extends Query {
               return scores[upTo] * boost;
             }
 
-            @Override
-            public int advanceShallow(int docid) {
-              int start = Math.max(upTo, lower);
-              int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase);
-              if (docidIndex < 0) {
-                docidIndex = -1 - docidIndex;
-              }
-              if (docidIndex >= upper) {
-                return NO_MORE_DOCS;
-              }
-              return docs[docidIndex];
-            }
-
             /**
              * move the implementation of docID() into a differently-named method so we can call it
              * from DocIDSetIterator.docID() even though this class is anonymous
diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
index 88d551619a4..dbb13d9f058 100644
--- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
+++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
@@ -244,36 +244,6 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
     }
   }
 
-  public void testAdvanceShallow() throws IOException {
-    try (Directory d = newDirectory()) {
-      try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
-        for (int j = 0; j < 5; j++) {
-          Document doc = new Document();
-          doc.add(getKnnVectorField("field", new float[] {j, j}));
-          w.addDocument(doc);
-        }
-      }
-      try (IndexReader reader = DirectoryReader.open(d)) {
-        IndexSearcher searcher = new IndexSearcher(reader);
-        AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
-        Query dasq = query.rewrite(searcher);
-        Scorer scorer =
-            dasq.createWeight(searcher, ScoreMode.COMPLETE, 1).scorer(reader.leaves().get(0));
-        // before advancing the iterator
-        assertEquals(1, scorer.advanceShallow(0));
-        assertEquals(1, scorer.advanceShallow(1));
-        assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10));
-
-        // after advancing the iterator
-        scorer.iterator().advance(2);
-        assertEquals(2, scorer.advanceShallow(0));
-        assertEquals(2, scorer.advanceShallow(2));
-        assertEquals(3, scorer.advanceShallow(3));
-        assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10));
-      }
-    }
-  }
-
   public void testScoreEuclidean() throws IOException {
     float[][] vectors = new float[5][];
     for (int j = 0; j < 5; j++) {
@@ -291,9 +261,6 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
       assertEquals(-1, scorer.docID());
       expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
 
-      // test getMaxScore
-      assertEquals(0, scorer.getMaxScore(-1), 0);
-      assertEquals(0, scorer.getMaxScore(0), 0);
       // This is 1 / ((l2distance((2,3), (2, 2)) = 1) + 1) = 0.5
       assertEquals(1 / 2f, scorer.getMaxScore(2), 0);
       assertEquals(1 / 2f, scorer.getMaxScore(Integer.MAX_VALUE), 0);
@@ -304,6 +271,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
       assertEquals(1 / 6f, scorer.score(), 0);
       assertEquals(3, it.advance(3));
       assertEquals(1 / 2f, scorer.score(), 0);
+
       assertEquals(NO_MORE_DOCS, it.advance(4));
       expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
     }
@@ -330,32 +298,30 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
         assertEquals(-1, scorer.docID());
         expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
 
-        // test getMaxScore
-        assertEquals(0, scorer.getMaxScore(-1), 0);
-        /* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
+        /* score0 = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
          * normalized by (1 + x) /2.
          */
-        float maxAtZero =
+        float score0 =
             (float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2);
-        assertEquals(maxAtZero, scorer.getMaxScore(0), 0.001);
 
-        /* max at 2 is actually the score for doc 1 which is the highest (since doc 1 vector (2, 4)
-         * is the closest to (2, 3)). This is ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
+        /* score1 = ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
          * normalized by (1 + x) /2
          */
-        float expected =
+        float score1 =
             (float) ((1 + (2 * 2 + 3 * 4) / Math.sqrt((2 * 2 + 3 * 3) * (2 * 2 + 4 * 4))) / 2);
-        assertEquals(expected, scorer.getMaxScore(2), 0);
-        assertEquals(expected, scorer.getMaxScore(Integer.MAX_VALUE), 0);
+
+        // doc 1 happens to have the maximum score
+        assertEquals(score1, scorer.getMaxScore(2), 0.0001);
+        assertEquals(score1, scorer.getMaxScore(Integer.MAX_VALUE), 0.0001);
 
         DocIdSetIterator it = scorer.iterator();
         assertEquals(3, it.cost());
         assertEquals(0, it.nextDoc());
         // doc 0 has (1, 1)
-        assertEquals(maxAtZero, scorer.score(), 0.0001);
+        assertEquals(score0, scorer.score(), 0.0001);
         assertEquals(1, it.advance(1));
-        assertEquals(expected, scorer.score(), 0);
-        assertEquals(2, it.nextDoc());
+        assertEquals(score1, scorer.score(), 0.0001);
+
         // since topK was 3
         assertEquals(NO_MORE_DOCS, it.advance(4));
         expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java
index c4f10f874be..04f5a53b246 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java
@@ -133,32 +133,30 @@ public class TestKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase {
         assertEquals(-1, scorer.docID());
         expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
 
-        // test getMaxScore
-        assertEquals(0, scorer.getMaxScore(-1), 0);
-        /* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
+        /* score0 = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
          * normalized by (1 + x) /2.
          */
-        float maxAtZero =
+        float score0 =
             (float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2);
-        assertEquals(maxAtZero, scorer.getMaxScore(0), 0.001);
 
-        /* max at 2 is actually the score for doc 1 which is the highest (since doc 1 vector (2, 4)
-         * is the closest to (2, 3)). This is ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
+        /* score1 = ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
          * normalized by (1 + x) /2
          */
-        float expected =
+        float score1 =
             (float) ((1 + (2 * 2 + 3 * 4) / Math.sqrt((2 * 2 + 3 * 3) * (2 * 2 + 4 * 4))) / 2);
-        assertEquals(expected, scorer.getMaxScore(2), 0);
-        assertEquals(expected, scorer.getMaxScore(Integer.MAX_VALUE), 0);
+
+        // doc 1 happens to have the max score
+        assertEquals(score1, scorer.getMaxScore(2), 0.0001);
+        assertEquals(score1, scorer.getMaxScore(Integer.MAX_VALUE), 0.0001);
 
         DocIdSetIterator it = scorer.iterator();
         assertEquals(3, it.cost());
         assertEquals(0, it.nextDoc());
         // doc 0 has (1, 1)
-        assertEquals(maxAtZero, scorer.score(), 0.0001);
+        assertEquals(score0, scorer.score(), 0.0001);
         assertEquals(1, it.advance(1));
-        assertEquals(expected, scorer.score(), 0);
-        assertEquals(2, it.nextDoc());
+        assertEquals(score1, scorer.score(), 0.0001);
+
         // since topK was 3
         assertEquals(NO_MORE_DOCS, it.advance(4));
         expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);