You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this rendering.
Posted to commits@lucene.apache.org by jp...@apache.org on 2022/12/15 17:19:30 UTC
[lucene] branch branch_9x updated: Fix SimpleTextKnnVectorsReader to handle changes introduced in GITHUB#12004 (#12024)
This is an automated email from the ASF dual-hosted git repository.
jpountz pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git
The following commit(s) were added to refs/heads/branch_9x by this push:
new aea6bab89a0 Fix SimpleTextKnnVectorsReader to handle changes introduced in GITHUB#12004 (#12024)
aea6bab89a0 is described below
commit aea6bab89a0445690c369d8f95cef11c13c50dec
Author: Benjamin Trent <be...@gmail.com>
AuthorDate: Thu Dec 15 08:49:47 2022 -0500
Fix SimpleTextKnnVectorsReader to handle changes introduced in GITHUB#12004 (#12024)
---
.../simpletext/SimpleTextKnnVectorsReader.java | 59 ++++++++++++++++++++--
.../org/apache/lucene/search/TestVectorScorer.java | 22 ++++++--
2 files changed, 71 insertions(+), 10 deletions(-)
diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java
index 3993e2e3bd5..4994cf692d4 100644
--- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java
+++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java
@@ -29,6 +29,7 @@ import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
@@ -140,7 +141,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
}
IndexInput bytesSlice =
dataIn.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
- return new SimpleTextVectorValues(fieldEntry, bytesSlice);
+ return new SimpleTextVectorValues(fieldEntry, bytesSlice, info.getVectorEncoding());
}
@Override
@@ -187,7 +188,42 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
@Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
- return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+ VectorValues values = getVectorValues(field);
+ if (target.length != values.dimension()) {
+ throw new IllegalArgumentException(
+ "vector query dimension: "
+ + target.length
+ + " differs from field dimension: "
+ + values.dimension());
+ }
+ FieldInfo info = readState.fieldInfos.fieldInfo(field);
+ VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction();
+ HitQueue topK = new HitQueue(k, false);
+
+ int numVisited = 0;
+ TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
+
+ int doc;
+ while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
+ if (acceptDocs != null && acceptDocs.get(doc) == false) {
+ continue;
+ }
+
+ if (numVisited >= visitedLimit) {
+ relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
+ break;
+ }
+
+ BytesRef vector = values.binaryValue();
+ float score = vectorSimilarity.compare(vector, target);
+ topK.insertWithOverflow(new ScoreDoc(doc, score));
+ numVisited++;
+ }
+ ScoreDoc[] topScoreDocs = new ScoreDoc[topK.size()];
+ for (int i = topScoreDocs.length - 1; i >= 0; i--) {
+ topScoreDocs[i] = topK.pop();
+ }
+ return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs);
}
@Override
@@ -273,16 +309,19 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
private final IndexInput in;
private final BytesRef binaryValue;
private final float[][] values;
+ private final VectorEncoding vectorEncoding;
int curOrd;
- SimpleTextVectorValues(FieldEntry entry, IndexInput in) throws IOException {
+ SimpleTextVectorValues(FieldEntry entry, IndexInput in, VectorEncoding vectorEncoding)
+ throws IOException {
this.entry = entry;
this.in = in;
values = new float[entry.size()][entry.dimension];
- binaryValue = new BytesRef(entry.dimension * Float.BYTES);
+ binaryValue = new BytesRef(entry.dimension * vectorEncoding.byteSize);
binaryValue.length = binaryValue.bytes.length;
curOrd = -1;
+ this.vectorEncoding = vectorEncoding;
readAllVectors();
}
@@ -303,7 +342,17 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
@Override
public BytesRef binaryValue() {
- ByteBuffer.wrap(binaryValue.bytes).asFloatBuffer().get(values[curOrd]);
+ switch (vectorEncoding) {
+ // we know that the floats are really just byte values
+ case BYTE:
+ for (int i = 0; i < values[curOrd].length; i++) {
+ binaryValue.bytes[i + binaryValue.offset] = (byte) values[curOrd][i];
+ }
+ break;
+ case FLOAT32:
+ ByteBuffer.wrap(binaryValue.bytes).asFloatBuffer().get(values[curOrd]);
+ break;
+ }
return binaryValue;
}
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java b/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java
index 979283c7701..2c72baa73d3 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java
@@ -18,6 +18,7 @@ package org.apache.lucene.search;
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+import com.carrotsearch.randomizedtesting.generators.RandomPicks;
import java.io.IOException;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
@@ -36,13 +37,25 @@ import org.apache.lucene.util.BytesRef;
public class TestVectorScorer extends LuceneTestCase {
public void testFindAll() throws IOException {
+ VectorEncoding encoding = RandomPicks.randomFrom(random(), VectorEncoding.values());
try (Directory indexStore =
- getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
+ getIndexStore(
+ "field", encoding, new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
assert reader.leaves().size() == 1;
LeafReaderContext context = reader.leaves().get(0);
FieldInfo fieldInfo = context.reader().getFieldInfos().fieldInfo("field");
- VectorScorer vectorScorer = VectorScorer.create(context, fieldInfo, new float[] {1, 2});
+ final VectorScorer vectorScorer;
+ switch (encoding) {
+ case BYTE:
+ vectorScorer = VectorScorer.create(context, fieldInfo, new BytesRef(new byte[] {1, 2}));
+ break;
+ case FLOAT32:
+ vectorScorer = VectorScorer.create(context, fieldInfo, new float[] {1, 2});
+ break;
+ default:
+ throw new IllegalArgumentException("unexpected vector encoding: " + encoding);
+ }
int numDocs = 0;
for (int i = 0; i < reader.maxDoc(); i++) {
@@ -55,11 +68,10 @@ public class TestVectorScorer extends LuceneTestCase {
}
/** Creates a new directory and adds documents with the given vectors as kNN vector fields */
- private Directory getIndexStore(String field, float[]... contents) throws IOException {
+ private Directory getIndexStore(String field, VectorEncoding encoding, float[]... contents)
+ throws IOException {
Directory indexStore = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
- VectorEncoding encoding =
- VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
for (int i = 0; i < contents.length; ++i) {
Document doc = new Document();
if (encoding == VectorEncoding.BYTE) {