You are viewing a plain-text version of this content. The canonical version is available at its original link, which is not preserved in this plain-text copy.
Posted to commits@lucene.apache.org by jp...@apache.org on 2022/12/15 17:19:30 UTC

[lucene] branch branch_9x updated: Fix SimpleTextKnnVectorsReader to handle changes introduced in GITHUB#12004 (#12024)

This is an automated email from the ASF dual-hosted git repository.

jpountz pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/branch_9x by this push:
     new aea6bab89a0 Fix SimpleTextKnnVectorsReader to handle changes introduced in GITHUB#12004 (#12024)
aea6bab89a0 is described below

commit aea6bab89a0445690c369d8f95cef11c13c50dec
Author: Benjamin Trent <be...@gmail.com>
AuthorDate: Thu Dec 15 08:49:47 2022 -0500

    Fix SimpleTextKnnVectorsReader to handle changes introduced in GITHUB#12004 (#12024)
---
 .../simpletext/SimpleTextKnnVectorsReader.java     | 59 ++++++++++++++++++++--
 .../org/apache/lucene/search/TestVectorScorer.java | 22 ++++++--
 2 files changed, 71 insertions(+), 10 deletions(-)

diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java
index 3993e2e3bd5..4994cf692d4 100644
--- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java
+++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java
@@ -29,6 +29,7 @@ import org.apache.lucene.index.CorruptIndexException;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.IndexFileNames;
 import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.index.VectorValues;
 import org.apache.lucene.search.DocIdSetIterator;
@@ -140,7 +141,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
     }
     IndexInput bytesSlice =
         dataIn.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
-    return new SimpleTextVectorValues(fieldEntry, bytesSlice);
+    return new SimpleTextVectorValues(fieldEntry, bytesSlice, info.getVectorEncoding());
   }
 
   @Override
@@ -187,7 +188,42 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
   @Override
   public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
       throws IOException {
-    return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+    VectorValues values = getVectorValues(field);
+    if (target.length != values.dimension()) {
+      throw new IllegalArgumentException(
+          "vector query dimension: "
+              + target.length
+              + " differs from field dimension: "
+              + values.dimension());
+    }
+    FieldInfo info = readState.fieldInfos.fieldInfo(field);
+    VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction();
+    HitQueue topK = new HitQueue(k, false);
+
+    int numVisited = 0;
+    TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
+
+    int doc;
+    while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
+      if (acceptDocs != null && acceptDocs.get(doc) == false) {
+        continue;
+      }
+
+      if (numVisited >= visitedLimit) {
+        relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
+        break;
+      }
+
+      BytesRef vector = values.binaryValue();
+      float score = vectorSimilarity.compare(vector, target);
+      topK.insertWithOverflow(new ScoreDoc(doc, score));
+      numVisited++;
+    }
+    ScoreDoc[] topScoreDocs = new ScoreDoc[topK.size()];
+    for (int i = topScoreDocs.length - 1; i >= 0; i--) {
+      topScoreDocs[i] = topK.pop();
+    }
+    return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs);
   }
 
   @Override
@@ -273,16 +309,19 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
     private final IndexInput in;
     private final BytesRef binaryValue;
     private final float[][] values;
+    private final VectorEncoding vectorEncoding;
 
     int curOrd;
 
-    SimpleTextVectorValues(FieldEntry entry, IndexInput in) throws IOException {
+    SimpleTextVectorValues(FieldEntry entry, IndexInput in, VectorEncoding vectorEncoding)
+        throws IOException {
       this.entry = entry;
       this.in = in;
       values = new float[entry.size()][entry.dimension];
-      binaryValue = new BytesRef(entry.dimension * Float.BYTES);
+      binaryValue = new BytesRef(entry.dimension * vectorEncoding.byteSize);
       binaryValue.length = binaryValue.bytes.length;
       curOrd = -1;
+      this.vectorEncoding = vectorEncoding;
       readAllVectors();
     }
 
@@ -303,7 +342,17 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
 
     @Override
     public BytesRef binaryValue() {
-      ByteBuffer.wrap(binaryValue.bytes).asFloatBuffer().get(values[curOrd]);
+      switch (vectorEncoding) {
+          // we know that the floats are really just byte values
+        case BYTE:
+          for (int i = 0; i < values[curOrd].length; i++) {
+            binaryValue.bytes[i + binaryValue.offset] = (byte) values[curOrd][i];
+          }
+          break;
+        case FLOAT32:
+          ByteBuffer.wrap(binaryValue.bytes).asFloatBuffer().get(values[curOrd]);
+          break;
+      }
       return binaryValue;
     }
 
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java b/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java
index 979283c7701..2c72baa73d3 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java
@@ -18,6 +18,7 @@ package org.apache.lucene.search;
 
 import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
 
+import com.carrotsearch.randomizedtesting.generators.RandomPicks;
 import java.io.IOException;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
@@ -36,13 +37,25 @@ import org.apache.lucene.util.BytesRef;
 public class TestVectorScorer extends LuceneTestCase {
 
   public void testFindAll() throws IOException {
+    VectorEncoding encoding = RandomPicks.randomFrom(random(), VectorEncoding.values());
     try (Directory indexStore =
-            getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
+            getIndexStore(
+                "field", encoding, new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
         IndexReader reader = DirectoryReader.open(indexStore)) {
       assert reader.leaves().size() == 1;
       LeafReaderContext context = reader.leaves().get(0);
       FieldInfo fieldInfo = context.reader().getFieldInfos().fieldInfo("field");
-      VectorScorer vectorScorer = VectorScorer.create(context, fieldInfo, new float[] {1, 2});
+      final VectorScorer vectorScorer;
+      switch (encoding) {
+        case BYTE:
+          vectorScorer = VectorScorer.create(context, fieldInfo, new BytesRef(new byte[] {1, 2}));
+          break;
+        case FLOAT32:
+          vectorScorer = VectorScorer.create(context, fieldInfo, new float[] {1, 2});
+          break;
+        default:
+          throw new IllegalArgumentException("unexpected vector encoding: " + encoding);
+      }
 
       int numDocs = 0;
       for (int i = 0; i < reader.maxDoc(); i++) {
@@ -55,11 +68,10 @@ public class TestVectorScorer extends LuceneTestCase {
   }
 
   /** Creates a new directory and adds documents with the given vectors as kNN vector fields */
-  private Directory getIndexStore(String field, float[]... contents) throws IOException {
+  private Directory getIndexStore(String field, VectorEncoding encoding, float[]... contents)
+      throws IOException {
     Directory indexStore = newDirectory();
     RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
-    VectorEncoding encoding =
-        VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
     for (int i = 0; i < contents.length; ++i) {
       Document doc = new Document();
       if (encoding == VectorEncoding.BYTE) {