You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by so...@apache.org on 2021/10/07 18:09:54 UTC

[lucene] branch main updated: LUCENE-10147: ensure that KnnVectorQuery scores are positive (#361)

This is an automated email from the ASF dual-hosted git repository.

sokolov pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/main by this push:
     new 9b1fc0e  LUCENE-10147: ensure that KnnVectorQuery scores are positive (#361)
9b1fc0e is described below

commit 9b1fc0ecc85365b202955c4731458fce19c5ba28
Author: Michael Sokolov <so...@falutin.net>
AuthorDate: Thu Oct 7 14:09:48 2021 -0400

    LUCENE-10147: ensure that KnnVectorQuery scores are positive (#361)
---
 .../codecs/lucene90/Lucene90HnswVectorsReader.java | 15 ++--
 .../lucene/index/VectorSimilarityFunction.java     | 19 +++++
 .../java/org/apache/lucene/util/VectorUtil.java    | 11 ++-
 .../apache/lucene/search/TestKnnVectorQuery.java   | 97 +++++++++++++++++++++-
 4 files changed, 126 insertions(+), 16 deletions(-)

diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
index 9a1c16f..56dcf89 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
@@ -66,7 +66,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
   Lucene90HnswVectorsReader(SegmentReadState state) throws IOException {
     this.fieldInfos = state.fieldInfos;
 
-    int versionMeta = readMetadata(state, Lucene90HnswVectorsFormat.META_EXTENSION);
+    int versionMeta = readMetadata(state);
     long[] checksumRef = new long[1];
     boolean success = false;
     try {
@@ -93,9 +93,10 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
     checksumSeed = checksumRef[0];
   }
 
-  private int readMetadata(SegmentReadState state, String fileExtension) throws IOException {
+  private int readMetadata(SegmentReadState state) throws IOException {
     String metaFileName =
-        IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
+        IndexFileNames.segmentFileName(
+            state.segmentInfo.name, state.segmentSuffix, Lucene90HnswVectorsFormat.META_EXTENSION);
     int versionMeta = -1;
     try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName, state.context)) {
       Throwable priorE = null;
@@ -255,14 +256,10 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
             random);
     int i = 0;
     ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
-    boolean reversed = fieldEntry.similarityFunction.reversed;
     while (results.size() > 0) {
       int node = results.topNode();
-      float score = results.topScore();
+      float score = fieldEntry.similarityFunction.convertToScore(results.topScore());
       results.pop();
-      if (reversed) {
-        score = 1 / (1 + score);
-      }
       scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], score);
     }
     // always return >= the case where we can assert == is only when there are fewer than topK
@@ -358,7 +355,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
   }
 
   /** Read the vector values from the index input. This supports both iterated and random access. */
-  private class OffHeapVectorValues extends VectorValues
+  private static class OffHeapVectorValues extends VectorValues
       implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
 
     final FieldEntry fieldEntry;
diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java
index 575a843..8905d49 100644
--- a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java
+++ b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java
@@ -32,6 +32,11 @@ public enum VectorSimilarityFunction {
     public float compare(float[] v1, float[] v2) {
       return squareDistance(v1, v2);
     }
+
+    @Override
+    public float convertToScore(float similarity) {
+      return 1 / (1 + similarity);
+    }
   },
 
   /** Dot product */
@@ -40,6 +45,11 @@ public enum VectorSimilarityFunction {
     public float compare(float[] v1, float[] v2) {
       return dotProduct(v1, v2);
     }
+
+    @Override
+    public float convertToScore(float similarity) {
+      return (1 + similarity) / 2;
+    }
   };
 
   /**
@@ -65,4 +75,13 @@ public enum VectorSimilarityFunction {
    * @return the value of the similarity function applied to the two vectors
    */
   public abstract float compare(float[] v1, float[] v2);
+
+  /**
+   * Converts raw similarity values (which may be negative, reversed, etc.) into document scores,
+   * which must be non-negative, with higher scores representing better matches.
+   *
+   * @param similarity the raw internal score as returned by {@link #compare(float[], float[])}.
+   * @return the document score corresponding to {@code similarity}
+   */
+  public abstract float convertToScore(float similarity);
 }
diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
index 38f3bd2..5149ad6 100644
--- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
+++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
@@ -115,9 +115,12 @@ public final class VectorUtil {
   /**
    * Modifies the argument to be unit length, dividing by its l2-norm. IllegalArgumentException is
    * thrown for zero vectors.
+   *
+   * @return the input array after normalization
    */
-  public static void l2normalize(float[] v) {
+  public static float[] l2normalize(float[] v) {
     l2normalize(v, true);
+    return v;
   }
 
   /**
@@ -125,9 +128,10 @@ public final class VectorUtil {
    *
    * @param v the vector to normalize
    * @param throwOnZero whether to throw an exception when <code>v</code> has all zeros
+   * @return the input array after normalization
    * @throws IllegalArgumentException when the vector is all zero and throwOnZero is true
    */
-  public static void l2normalize(float[] v, boolean throwOnZero) {
+  public static float[] l2normalize(float[] v, boolean throwOnZero) {
     double squareSum = 0.0f;
     int dim = v.length;
     for (float x : v) {
@@ -137,13 +141,14 @@ public final class VectorUtil {
       if (throwOnZero) {
         throw new IllegalArgumentException("Cannot normalize a zero-length vector");
       } else {
-        return;
+        return v;
       }
     }
     double length = Math.sqrt(squareSum);
     for (int i = 0; i < dim; i++) {
       v[i] /= length;
     }
+    return v;
   }
 
   /**
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
index db6c045..d652517 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
@@ -17,6 +17,7 @@
 package org.apache.lucene.search;
 
 import static com.carrotsearch.randomizedtesting.RandomizedTest.frequently;
+import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 import static org.apache.lucene.util.TestVectorUtil.randomVector;
 
@@ -33,8 +34,10 @@ import org.apache.lucene.index.IndexWriter;
 import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.index.RandomIndexWriter;
 import org.apache.lucene.index.Term;
+import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.util.LuceneTestCase;
+import org.apache.lucene.util.VectorUtil;
 
 /** TestKnnVectorQuery tests KnnVectorQuery. */
 public class TestKnnVectorQuery extends LuceneTestCase {
@@ -164,12 +167,13 @@ public class TestKnnVectorQuery extends LuceneTestCase {
     }
   }
 
-  public void testScore() throws IOException {
+  public void testScoreEuclidean() throws IOException {
     try (Directory d = newDirectory()) {
       try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
         for (int j = 0; j < 5; j++) {
           Document doc = new Document();
-          doc.add(new KnnVectorField("field", new float[] {j, j}));
+          doc.add(
+              new KnnVectorField("field", new float[] {j, j}, VectorSimilarityFunction.EUCLIDEAN));
           w.addDocument(doc);
         }
       }
@@ -183,7 +187,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
 
         // prior to advancing, score is 0
         assertEquals(-1, scorer.docID());
-        expectThrows(ArrayIndexOutOfBoundsException.class, () -> scorer.score());
+        expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
 
         // test getMaxScore
         assertEquals(0, scorer.getMaxScore(-1), 0);
@@ -199,7 +203,92 @@ public class TestKnnVectorQuery extends LuceneTestCase {
         assertEquals(3, it.advance(3));
         assertEquals(1 / 2f, scorer.score(), 0);
         assertEquals(NO_MORE_DOCS, it.advance(4));
-        expectThrows(ArrayIndexOutOfBoundsException.class, () -> scorer.score());
+        expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
+      }
+    }
+  }
+
+  public void testScoreDotProduct() throws IOException {
+    try (Directory d = newDirectory()) {
+      try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
+        for (int j = 1; j <= 5; j++) {
+          Document doc = new Document();
+          doc.add(
+              new KnnVectorField(
+                  "field", VectorUtil.l2normalize(new float[] {j, j * j}), DOT_PRODUCT));
+          w.addDocument(doc);
+        }
+      }
+      try (IndexReader reader = DirectoryReader.open(d)) {
+        assertEquals(1, reader.leaves().size());
+        IndexSearcher searcher = new IndexSearcher(reader);
+        KnnVectorQuery query =
+            new KnnVectorQuery("field", VectorUtil.l2normalize(new float[] {2, 3}), 3);
+        Query rewritten = query.rewrite(reader);
+        Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
+        Scorer scorer = weight.scorer(reader.leaves().get(0));
+
+        // prior to advancing, score is undefined
+        assertEquals(-1, scorer.docID());
+        expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
+
+        // test getMaxScore
+        assertEquals(0, scorer.getMaxScore(-1), 0);
+        /* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)) ~ 0.9806, then
+         * normalized by (1 + x) / 2 ~ 0.9903.
+         */
+        float maxAtZero = 0.99029f;
+        assertEquals(maxAtZero, scorer.getMaxScore(0), 0.001);
+
+        /* max at 2 is actually the score for doc 1 which is the highest (since doc 1 vector (2, 4)
+         * is the closest to (2, 3)). This is ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
+         * normalized by (1 + x) /2
+         */
+        float expected =
+            (float) ((1 + (2 * 2 + 3 * 4) / Math.sqrt((2 * 2 + 3 * 3) * (2 * 2 + 4 * 4))) / 2);
+        assertEquals(expected, scorer.getMaxScore(2), 0);
+        assertEquals(expected, scorer.getMaxScore(Integer.MAX_VALUE), 0);
+
+        DocIdSetIterator it = scorer.iterator();
+        assertEquals(3, it.cost());
+        assertEquals(0, it.nextDoc());
+        // doc 0 has (1, 1)
+        assertEquals(maxAtZero, scorer.score(), 0.0001);
+        assertEquals(1, it.advance(1));
+        assertEquals(expected, scorer.score(), 0);
+        assertEquals(2, it.nextDoc());
+        // since topK was 3
+        assertEquals(NO_MORE_DOCS, it.advance(4));
+        expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
+      }
+    }
+  }
+
+  public void testScoreNegativeDotProduct() throws IOException {
+    try (Directory d = newDirectory()) {
+      try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
+        Document doc = new Document();
+        doc.add(new KnnVectorField("field", new float[] {-1, 0}, DOT_PRODUCT));
+        w.addDocument(doc);
+        doc = new Document();
+        doc.add(new KnnVectorField("field", new float[] {1, 0}, DOT_PRODUCT));
+        w.addDocument(doc);
+      }
+      try (IndexReader reader = DirectoryReader.open(d)) {
+        assertEquals(1, reader.leaves().size());
+        IndexSearcher searcher = new IndexSearcher(reader);
+        KnnVectorQuery query = new KnnVectorQuery("field", new float[] {1, 0}, 2);
+        Query rewritten = query.rewrite(reader);
+        Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
+        Scorer scorer = weight.scorer(reader.leaves().get(0));
+
+        // scores are normalized to lie in [0, 1]
+        DocIdSetIterator it = scorer.iterator();
+        assertEquals(2, it.cost());
+        assertEquals(0, it.nextDoc());
+        assertEquals(0, scorer.score(), 0);
+        assertEquals(1, it.advance(1));
+        assertEquals(1, scorer.score(), 0);
       }
     }
   }