You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by ju...@apache.org on 2023/02/13 22:00:23 UTC

[lucene] 01/01: Simplify max score for kNN vector queries

This is an automated email from the ASF dual-hosted git repository.

julietibs pushed a commit to branch jtibshirani/max-score
in repository https://gitbox.apache.org/repos/asf/lucene.git

commit ebbb86313f989b5be47865295fb65e937532c849
Author: Julie Tibshirani <ju...@sourcegraph.com>
AuthorDate: Mon Feb 13 13:43:10 2023 -0800

    Simplify max score for kNN vector queries
    
    The helper class DocAndScoreQuery implements advanceShallow to help skip
    non-competitive documents. However, this method doesn't actually keep track
    of the position it has advanced to, which means it can do extra work.
    
    Overall the complexity here didn't seem worth it, given the low cost of
    collecting matching kNN docs. This PR switches to a simpler approach, which
    computes the maximum score across all collected top-k hits once and uses it
    as a fixed upper bound on the max score.
---
 .../lucene/search/AbstractKnnVectorQuery.java      | 28 +++-------
 .../lucene/search/BaseKnnVectorQueryTestCase.java  | 60 +++++-----------------
 .../apache/lucene/search/TestKnnVectorQuery.java   | 24 ++++-----
 3 files changed, 32 insertions(+), 80 deletions(-)

diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
index e6d1d00a5c0..a9f504373dc 100644
--- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
@@ -192,12 +192,16 @@ abstract class AbstractKnnVectorQuery extends Query {
     Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
     int[] docs = new int[len];
     float[] scores = new float[len];
+    float maxScore = 0.0f;
     for (int i = 0; i < len; i++) {
       docs[i] = topK.scoreDocs[i].doc;
       scores[i] = topK.scoreDocs[i].score;
+      if (scores[i] > maxScore) {
+        maxScore = scores[i];
+      }
     }
     int[] segmentStarts = findSegmentStarts(reader, docs);
-    return new DocAndScoreQuery(k, docs, scores, segmentStarts, reader.getContext().id());
+    return new DocAndScoreQuery(k, docs, scores, maxScore, segmentStarts, reader.getContext().id());
   }
 
   private int[] findSegmentStarts(IndexReader reader, int[] docs) {
@@ -244,6 +248,7 @@ abstract class AbstractKnnVectorQuery extends Query {
     private final int k;
     private final int[] docs;
     private final float[] scores;
+    private final float maxScore;
     private final int[] segmentStarts;
     private final Object contextIdentity;
 
@@ -261,10 +266,11 @@ abstract class AbstractKnnVectorQuery extends Query {
      *     query
      */
     DocAndScoreQuery(
-        int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
+        int k, int[] docs, float[] scores, float maxScore, int[] segmentStarts, Object contextIdentity) {
       this.k = k;
       this.docs = docs;
       this.scores = scores;
+      this.maxScore = maxScore;
       this.segmentStarts = segmentStarts;
       this.contextIdentity = contextIdentity;
     }
@@ -325,11 +331,6 @@ abstract class AbstractKnnVectorQuery extends Query {
 
             @Override
             public float getMaxScore(int docId) {
-              docId += context.docBase;
-              float maxScore = 0;
-              for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docId; idx++) {
-                maxScore = Math.max(maxScore, scores[idx]);
-              }
               return maxScore * boost;
             }
 
@@ -338,19 +339,6 @@ abstract class AbstractKnnVectorQuery extends Query {
               return scores[upTo] * boost;
             }
 
-            @Override
-            public int advanceShallow(int docid) {
-              int start = Math.max(upTo, lower);
-              int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase);
-              if (docidIndex < 0) {
-                docidIndex = -1 - docidIndex;
-              }
-              if (docidIndex >= upper) {
-                return NO_MORE_DOCS;
-              }
-              return docs[docidIndex];
-            }
-
             /**
              * move the implementation of docID() into a differently-named method so we can call it
              * from DocIDSetIterator.docID() even though this class is anonymous
diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
index 1e579211e38..0a015a47249 100644
--- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
+++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
@@ -220,36 +220,6 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
     }
   }
 
-  public void testAdvanceShallow() throws IOException {
-    try (Directory d = newDirectory()) {
-      try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
-        for (int j = 0; j < 5; j++) {
-          Document doc = new Document();
-          doc.add(getKnnVectorField("field", new float[] {j, j}));
-          w.addDocument(doc);
-        }
-      }
-      try (IndexReader reader = DirectoryReader.open(d)) {
-        IndexSearcher searcher = new IndexSearcher(reader);
-        AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
-        Query dasq = query.rewrite(searcher);
-        Scorer scorer =
-            dasq.createWeight(searcher, ScoreMode.COMPLETE, 1).scorer(reader.leaves().get(0));
-        // before advancing the iterator
-        assertEquals(1, scorer.advanceShallow(0));
-        assertEquals(1, scorer.advanceShallow(1));
-        assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10));
-
-        // after advancing the iterator
-        scorer.iterator().advance(2);
-        assertEquals(2, scorer.advanceShallow(0));
-        assertEquals(2, scorer.advanceShallow(2));
-        assertEquals(3, scorer.advanceShallow(3));
-        assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10));
-      }
-    }
-  }
-
   public void testScoreEuclidean() throws IOException {
     float[][] vectors = new float[5][];
     for (int j = 0; j < 5; j++) {
@@ -267,10 +237,6 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
       assertEquals(-1, scorer.docID());
       expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
 
-      // test getMaxScore
-      assertEquals(0, scorer.getMaxScore(-1), 0);
-      assertEquals(0, scorer.getMaxScore(0), 0);
-      // This is 1 / ((l2distance((2,3), (2, 2)) = 1) + 1) = 0.5
       assertEquals(1 / 2f, scorer.getMaxScore(2), 0);
       assertEquals(1 / 2f, scorer.getMaxScore(Integer.MAX_VALUE), 0);
 
@@ -280,6 +246,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
       assertEquals(1 / 6f, scorer.score(), 0);
       assertEquals(3, it.advance(3));
       assertEquals(1 / 2f, scorer.score(), 0);
+
       assertEquals(NO_MORE_DOCS, it.advance(4));
       expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
     }
@@ -306,32 +273,31 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
         assertEquals(-1, scorer.docID());
         expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
 
-        // test getMaxScore
-        assertEquals(0, scorer.getMaxScore(-1), 0);
-        /* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
+        /* score0 = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
          * normalized by (1 + x) /2.
          */
-        float maxAtZero =
+        float score0 =
             (float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2);
-        assertEquals(maxAtZero, scorer.getMaxScore(0), 0.001);
 
-        /* max at 2 is actually the score for doc 1 which is the highest (since doc 1 vector (2, 4)
-         * is the closest to (2, 3)). This is ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
+        /* score1 = ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
          * normalized by (1 + x) /2
          */
-        float expected =
+        float score1 =
             (float) ((1 + (2 * 2 + 3 * 4) / Math.sqrt((2 * 2 + 3 * 3) * (2 * 2 + 4 * 4))) / 2);
-        assertEquals(expected, scorer.getMaxScore(2), 0);
-        assertEquals(expected, scorer.getMaxScore(Integer.MAX_VALUE), 0);
+
+        // doc 1 happens to have the maximum score
+        // doc 1 happens to have the maximum score
+        assertEquals(score1, scorer.getMaxScore(2), 0.0001);
+        assertEquals(score1, scorer.getMaxScore(Integer.MAX_VALUE), 0.0001);
 
         DocIdSetIterator it = scorer.iterator();
         assertEquals(3, it.cost());
         assertEquals(0, it.nextDoc());
         // doc 0 has (1, 1)
-        assertEquals(maxAtZero, scorer.score(), 0.0001);
+        assertEquals(score0, scorer.score(), 0.0001);
         assertEquals(1, it.advance(1));
-        assertEquals(expected, scorer.score(), 0);
-        assertEquals(2, it.nextDoc());
+        assertEquals(score1, scorer.score(), 0.0001);
+
         // since topK was 3
         assertEquals(NO_MORE_DOCS, it.advance(4));
         expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
index a751d37f0cd..459d23ea11f 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
@@ -124,32 +124,30 @@ public class TestKnnVectorQuery extends BaseKnnVectorQueryTestCase {
         assertEquals(-1, scorer.docID());
         expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
 
-        // test getMaxScore
-        assertEquals(0, scorer.getMaxScore(-1), 0);
-        /* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
+        /* score0 = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
          * normalized by (1 + x) /2.
          */
-        float maxAtZero =
+        float score0 =
             (float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2);
-        assertEquals(maxAtZero, scorer.getMaxScore(0), 0.001);
 
-        /* max at 2 is actually the score for doc 1 which is the highest (since doc 1 vector (2, 4)
-         * is the closest to (2, 3)). This is ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
+        /* score1 = ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
          * normalized by (1 + x) /2
          */
-        float expected =
+        float score1 =
             (float) ((1 + (2 * 2 + 3 * 4) / Math.sqrt((2 * 2 + 3 * 3) * (2 * 2 + 4 * 4))) / 2);
-        assertEquals(expected, scorer.getMaxScore(2), 0);
-        assertEquals(expected, scorer.getMaxScore(Integer.MAX_VALUE), 0);
+
+        // doc 1 happens to have the max score
+        assertEquals(score1, scorer.getMaxScore(2), 0.0001);
+        assertEquals(score1, scorer.getMaxScore(Integer.MAX_VALUE), 0.0001);
 
         DocIdSetIterator it = scorer.iterator();
         assertEquals(3, it.cost());
         assertEquals(0, it.nextDoc());
         // doc 0 has (1, 1)
-        assertEquals(maxAtZero, scorer.score(), 0.0001);
+        assertEquals(score0, scorer.score(), 0.0001);
         assertEquals(1, it.advance(1));
-        assertEquals(expected, scorer.score(), 0);
-        assertEquals(2, it.nextDoc());
+        assertEquals(score1, scorer.score(), 0.0001);
+
         // since topK was 3
         assertEquals(NO_MORE_DOCS, it.advance(4));
         expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);