You are viewing a plain-text version of this content. The canonical link for it (a hyperlink on the original archive page) is not preserved in this copy.
Posted to commits@lucene.apache.org by jp...@apache.org on 2022/12/14 13:57:15 UTC
[lucene] branch branch_9x updated: Move byte vector queries into new KnnByteVectorQuery (#12004) (#12018)
This is an automated email from the ASF dual-hosted git repository.
jpountz pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git
The following commit(s) were added to refs/heads/branch_9x by this push:
new d5cef1c0355 Move byte vector queries into new KnnByteVectorQuery (#12004) (#12018)
d5cef1c0355 is described below
commit d5cef1c03552223029ae02a9b8fa4613f024e34f
Author: Benjamin Trent <be...@gmail.com>
AuthorDate: Wed Dec 14 08:57:05 2022 -0500
Move byte vector queries into new KnnByteVectorQuery (#12004) (#12018)
---
lucene/CHANGES.txt | 3 +
.../lucene90/Lucene90HnswVectorsReader.java | 6 +
.../lucene91/Lucene91HnswVectorsReader.java | 6 +
.../lucene92/Lucene92HnswVectorsReader.java | 19 +-
.../lucene94/Lucene94HnswVectorsReader.java | 51 +-
.../simpletext/SimpleTextKnnVectorsReader.java | 6 +
.../lucene/codecs/BufferingKnnVectorsWriter.java | 12 +
.../org/apache/lucene/codecs/KnnVectorsFormat.java | 7 +
.../org/apache/lucene/codecs/KnnVectorsReader.java | 30 +
.../codecs/lucene95/Lucene95HnswVectorsReader.java | 47 ++
.../codecs/perfield/PerFieldKnnVectorsFormat.java | 16 +-
.../java/org/apache/lucene/index/CheckIndex.java | 26 +-
.../java/org/apache/lucene/index/CodecReader.java | 14 +
.../apache/lucene/index/DocValuesLeafReader.java | 7 +
.../org/apache/lucene/index/FilterLeafReader.java | 6 +
.../java/org/apache/lucene/index/LeafReader.java | 29 +
.../apache/lucene/index/ParallelLeafReader.java | 12 +
.../lucene/index/SlowCodecReaderWrapper.java | 7 +
.../apache/lucene/index/SortingCodecReader.java | 6 +
...ectorQuery.java => AbstractKnnVectorQuery.java} | 67 +-
.../apache/lucene/search/KnnByteVectorQuery.java | 116 +++
.../org/apache/lucene/search/KnnVectorQuery.java | 373 +--------
.../org/apache/lucene/search/VectorScorer.java | 22 +-
.../apache/lucene/util/hnsw/HnswGraphSearcher.java | 35 +-
.../lucene/index/TestSegmentToThreadMapping.java | 7 +
...rQuery.java => BaseKnnVectorQueryTestCase.java} | 289 ++-----
.../lucene/search/TestKnnByteVectorQuery.java | 97 +++
.../apache/lucene/search/TestKnnVectorQuery.java | 908 +--------------------
.../org/apache/lucene/util/TestVectorUtil.java | 11 +
.../org/apache/lucene/util/hnsw/TestHnswGraph.java | 201 +++--
.../search/highlight/TermVectorLeafReader.java | 7 +
.../apache/lucene/index/memory/MemoryIndex.java | 6 +
.../asserting/AssertingKnnVectorsFormat.java | 19 +-
.../lucene/tests/index/MergeReaderWrapper.java | 7 +
.../org/apache/lucene/tests/search/QueryUtils.java | 7 +
35 files changed, 898 insertions(+), 1584 deletions(-)
diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 323f3181e40..872b6d672e6 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -46,6 +46,9 @@ API Changes
* GITHUB#11984: Improved TimeLimitBulkScorer to check the timeout at exponantial rate.
(Costin Leau)
+* GITHUB#12004: Add new KnnByteVectorQuery for querying vector fields that are encoded as BYTE. Removes the ability to
+ use KnnVectorQuery against fields encoded as BYTE (Ben Trent)
+
New Features
---------------------
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java
index 8377072a07f..3c0f28706a3 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java
@@ -276,6 +276,12 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
}
+ @Override
+ public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+ throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java
index 752b87d0179..f75caec1fc9 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java
@@ -266,6 +266,12 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
}
+ @Override
+ public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+ throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java
index fd0ccf5e63c..7e0ea326980 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java
@@ -40,6 +40,7 @@ import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
@@ -54,13 +55,11 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
*/
public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
- private final FieldInfos fieldInfos;
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
Lucene92HnswVectorsReader(SegmentReadState state) throws IOException {
- this.fieldInfos = state.fieldInfos;
int versionMeta = readMetadata(state);
boolean success = false;
try {
@@ -260,18 +259,10 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
}
- /** Get knn graph values; used for testing */
- public HnswGraph getGraph(String field) throws IOException {
- FieldInfo info = fieldInfos.fieldInfo(field);
- if (info == null) {
- throw new IllegalArgumentException("No such field '" + field + "'");
- }
- FieldEntry entry = fields.get(field);
- if (entry != null && entry.vectorIndexLength > 0) {
- return getGraph(entry);
- } else {
- return HnswGraph.EMPTY;
- }
+ @Override
+ public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+ throws IOException {
+ throw new UnsupportedOperationException();
}
private HnswGraph getGraph(FieldEntry entry) throws IOException {
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java
index afa5e19e45d..253b933a736 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java
@@ -41,6 +41,7 @@ import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
@@ -55,13 +56,11 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
*/
public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
- private final FieldInfos fieldInfos;
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
Lucene94HnswVectorsReader(SegmentReadState state) throws IOException {
- this.fieldInfos = state.fieldInfos;
int versionMeta = readMetadata(state);
boolean success = false;
try {
@@ -255,7 +254,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
throws IOException {
FieldEntry fieldEntry = fields.get(field);
- if (fieldEntry.size() == 0) {
+ if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
}
@@ -290,18 +289,44 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
}
- /** Get knn graph values; used for testing */
- public HnswGraph getGraph(String field) throws IOException {
- FieldInfo info = fieldInfos.fieldInfo(field);
- if (info == null) {
- throw new IllegalArgumentException("No such field '" + field + "'");
+ @Override
+ public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+ throws IOException {
+ FieldEntry fieldEntry = fields.get(field);
+
+ if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
+ return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
}
- FieldEntry entry = fields.get(field);
- if (entry != null && entry.vectorIndexLength > 0) {
- return getGraph(entry);
- } else {
- return HnswGraph.EMPTY;
+
+ // bound k by total number of vectors to prevent oversizing data structures
+ k = Math.min(k, fieldEntry.size());
+ OffHeapVectorValues vectorValues = OffHeapVectorValues.load(fieldEntry, vectorData);
+
+ NeighborQueue results =
+ HnswGraphSearcher.search(
+ target,
+ k,
+ vectorValues,
+ fieldEntry.vectorEncoding,
+ fieldEntry.similarityFunction,
+ getGraph(fieldEntry),
+ vectorValues.getAcceptOrds(acceptDocs),
+ visitedLimit);
+
+ int i = 0;
+ ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
+ while (results.size() > 0) {
+ int node = results.topNode();
+ float score = results.topScore();
+ results.pop();
+ scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(vectorValues.ordToDoc(node), score);
}
+
+ TotalHits.Relation relation =
+ results.incomplete()
+ ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
+ : TotalHits.Relation.EQUAL_TO;
+ return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
}
private HnswGraph getGraph(FieldEntry entry) throws IOException {
diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java
index a3d7753871b..3993e2e3bd5 100644
--- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java
+++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java
@@ -184,6 +184,12 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs);
}
+ @Override
+ public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+ throws IOException {
+ return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+ }
+
@Override
public void checkIntegrity() throws IOException {
IndexInput clone = dataIn.clone();
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java
index 59288b127a6..6010918fb57 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java
@@ -90,6 +90,12 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
throw new UnsupportedOperationException();
}
+
+ @Override
+ public TopDocs search(
+ String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
+ throw new UnsupportedOperationException();
+ }
};
writeField(fieldData.fieldInfo, knnVectorsReader, maxDoc);
@@ -185,6 +191,12 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
throw new UnsupportedOperationException();
}
+ @Override
+ public TopDocs search(
+ String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
+ throw new UnsupportedOperationException();
+ }
+
@Override
public VectorValues getVectorValues(String field) throws IOException {
return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java
index 2fa4628aa77..96a79cde891 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java
@@ -23,6 +23,7 @@ import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.NamedSPILoader;
/**
@@ -103,6 +104,12 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
throw new UnsupportedOperationException();
}
+ @Override
+ public TopDocs search(
+ String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
+ throw new UnsupportedOperationException();
+ }
+
@Override
public void close() {}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java
index 711e5e59c19..6741674e01d 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java
@@ -26,6 +26,7 @@ import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
/** Reads vectors from an index. */
public abstract class KnnVectorsReader implements Closeable, Accountable {
@@ -80,6 +81,35 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
public abstract TopDocs search(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
+ /**
+ * Return the k nearest neighbor documents as determined by comparison of their vector values for
+ * this field, to the given vector, by the field's similarity function. The score of each document
+ * is derived from the vector similarity in a way that ensures scores are positive and that a
+ * larger score corresponds to a higher ranking.
+ *
+ * <p>The search is allowed to be approximate, meaning the results are not guaranteed to be the
+ * true k closest neighbors. For large values of k (for example when k is close to the total
+ * number of documents), the search may also retrieve fewer than k documents.
+ *
+ * <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor, in
+ * order of their similarity to the query vector (decreasing scores). The {@link TotalHits}
+ * contains the number of documents visited during the search. If the search stopped early because
+ * it hit {@code visitedLimit}, it is indicated through the relation {@code
+ * TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO}.
+ *
+ * <p>The behavior is undefined if the given field doesn't have KNN vectors enabled on its {@link
+ * FieldInfo}. The return value is never {@code null}.
+ *
+ * @param field the vector field to search
+ * @param target the vector-valued query
+ * @param k the number of docs to return
+ * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
+ * if they are all allowed to match.
+ * @param visitedLimit the maximum number of nodes that the search is allowed to visit
+ * @return the k nearest neighbor documents, along with their (similarity-specific) scores.
+ */
+ public abstract TopDocs search(
+ String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
/**
* Returns an instance optimized for merging. This instance may only be consumed in the thread
* that called {@link #getMergeInstance()}.
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java
index d670e411a11..e22d01cb06c 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java
@@ -43,6 +43,7 @@ import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
@@ -261,6 +262,52 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader {
if (fieldEntry.size() == 0) {
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
}
+ if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
+ return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+ }
+
+ // bound k by total number of vectors to prevent oversizing data structures
+ k = Math.min(k, fieldEntry.size());
+ OffHeapVectorValues vectorValues = OffHeapVectorValues.load(fieldEntry, vectorData);
+
+ NeighborQueue results =
+ HnswGraphSearcher.search(
+ target,
+ k,
+ vectorValues,
+ fieldEntry.vectorEncoding,
+ fieldEntry.similarityFunction,
+ getGraph(fieldEntry),
+ vectorValues.getAcceptOrds(acceptDocs),
+ visitedLimit);
+
+ int i = 0;
+ ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
+ while (results.size() > 0) {
+ int node = results.topNode();
+ float score = results.topScore();
+ results.pop();
+ scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(vectorValues.ordToDoc(node), score);
+ }
+
+ TotalHits.Relation relation =
+ results.incomplete()
+ ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
+ : TotalHits.Relation.EQUAL_TO;
+ return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
+ }
+
+ @Override
+ public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+ throws IOException {
+ FieldEntry fieldEntry = fields.get(field);
+
+ if (fieldEntry.size() == 0) {
+ return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+ }
+ if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
+ return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+ }
// bound k by total number of vectors to prevent oversizing data structures
k = Math.min(k, fieldEntry.size());
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java
index e30d24fbc8e..428311e92e4 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java
@@ -33,10 +33,9 @@ import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorValues;
-import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
-import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
/**
@@ -259,12 +258,13 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
- KnnVectorsReader knnVectorsReader = fields.get(field);
- if (knnVectorsReader == null) {
- return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
- } else {
- return knnVectorsReader.search(field, target, k, acceptDocs, visitedLimit);
- }
+ return fields.get(field).search(field, target, k, acceptDocs, visitedLimit);
+ }
+
+ @Override
+ public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+ throws IOException {
+ return fields.get(field).search(field, target, k, acceptDocs, visitedLimit);
}
@Override
diff --git a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java
index 86785b65460..b9495832ba2 100644
--- a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java
+++ b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java
@@ -2598,18 +2598,34 @@ public final class CheckIndex implements Closeable {
int docCount = 0;
int everyNdoc = Math.max(values.size() / 64, 1);
while (values.nextDoc() != NO_MORE_DOCS) {
- float[] vectorValue = values.vectorValue();
// search the first maxNumSearches vectors to exercise the graph
if (values.docID() % everyNdoc == 0) {
- TopDocs docs =
- reader
- .getVectorReader()
- .search(fieldInfo.name, vectorValue, 10, null, Integer.MAX_VALUE);
+ final TopDocs docs;
+ switch (fieldInfo.getVectorEncoding()) {
+ case FLOAT32:
+ docs =
+ reader
+ .getVectorReader()
+ .search(
+ fieldInfo.name, values.vectorValue(), 10, null, Integer.MAX_VALUE);
+ break;
+ case BYTE:
+ docs =
+ reader
+ .getVectorReader()
+ .search(
+ fieldInfo.name, values.binaryValue(), 10, null, Integer.MAX_VALUE);
+ break;
+ default:
+ throw new IllegalArgumentException(
+ "unknown vector encoding: " + fieldInfo.getVectorEncoding());
+ }
if (docs.scoreDocs.length == 0) {
throw new CheckIndexException(
"Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
}
}
+ float[] vectorValue = values.vectorValue();
int valueLength = vectorValue.length;
if (valueLength != dimension) {
throw new CheckIndexException(
diff --git a/lucene/core/src/java/org/apache/lucene/index/CodecReader.java b/lucene/core/src/java/org/apache/lucene/index/CodecReader.java
index 7220756d0cf..70a4afd014b 100644
--- a/lucene/core/src/java/org/apache/lucene/index/CodecReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/CodecReader.java
@@ -27,6 +27,7 @@ import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
/** LeafReader implemented by codec APIs. */
public abstract class CodecReader extends LeafReader {
@@ -253,6 +254,19 @@ public abstract class CodecReader extends LeafReader {
return getVectorReader().search(field, target, k, acceptDocs, visitedLimit);
}
+ @Override
+ public final TopDocs searchNearestVectors(
+ String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
+ ensureOpen();
+ FieldInfo fi = getFieldInfos().fieldInfo(field);
+ if (fi == null || fi.getVectorDimension() == 0) {
+ // Field does not exist or does not index vectors
+ return null;
+ }
+
+ return getVectorReader().search(field, target, k, acceptDocs, visitedLimit);
+ }
+
@Override
protected void doClose() throws IOException {}
diff --git a/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java
index 31b50d4f381..99ce3bcd980 100644
--- a/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java
@@ -20,6 +20,7 @@ package org.apache.lucene.index;
import java.io.IOException;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
abstract class DocValuesLeafReader extends LeafReader {
@Override
@@ -58,6 +59,12 @@ abstract class DocValuesLeafReader extends LeafReader {
throw new UnsupportedOperationException();
}
+ @Override
+ public TopDocs searchNearestVectors(
+ String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
@Override
public final void checkIntegrity() throws IOException {
throw new UnsupportedOperationException();
diff --git a/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java
index 11f8c223352..cedec13a1d0 100644
--- a/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java
@@ -357,6 +357,12 @@ public abstract class FilterLeafReader extends LeafReader {
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
}
+ @Override
+ public TopDocs searchNearestVectors(
+ String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
+ return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
+ }
+
@Override
public Fields getTermVectors(int docID) throws IOException {
ensureOpen();
diff --git a/lucene/core/src/java/org/apache/lucene/index/LeafReader.java b/lucene/core/src/java/org/apache/lucene/index/LeafReader.java
index 017c085b540..e71ff1f3edf 100644
--- a/lucene/core/src/java/org/apache/lucene/index/LeafReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/LeafReader.java
@@ -21,6 +21,7 @@ import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
/**
* {@code LeafReader} is an abstract class, providing an interface for accessing an index. Search of
@@ -235,6 +236,34 @@ public abstract class LeafReader extends IndexReader {
public abstract TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
+ /**
+ * Return the k nearest neighbor documents as determined by comparison of their vector values for
+ * this field, to the given vector, by the field's similarity function. The score of each document
+ * is derived from the vector similarity in a way that ensures scores are positive and that a
+ * larger score corresponds to a higher ranking.
+ *
+ * <p>The search is allowed to be approximate, meaning the results are not guaranteed to be the
+ * true k closest neighbors. For large values of k (for example when k is close to the total
+ * number of documents), the search may also retrieve fewer than k documents.
+ *
+ * <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor,
+ * sorted in order of their similarity to the query vector (decreasing scores). The {@link
+ * TotalHits} contains the number of documents visited during the search. If the search stopped
+ * early because it hit {@code visitedLimit}, it is indicated through the relation {@code
+ * TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO}.
+ *
+ * @param field the vector field to search
+ * @param target the vector-valued query
+ * @param k the number of docs to return
+ * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
+ * if they are all allowed to match.
+ * @param visitedLimit the maximum number of nodes that the search is allowed to visit
+ * @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores.
+ * @lucene.experimental
+ */
+ public abstract TopDocs searchNearestVectors(
+ String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
+
/**
* Get the {@link FieldInfos} describing all fields in this reader.
*
diff --git a/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java
index 006127375ec..8f13f91ea28 100644
--- a/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java
@@ -29,6 +29,7 @@ import java.util.TreeMap;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.Version;
/**
@@ -452,6 +453,17 @@ public class ParallelLeafReader extends LeafReader {
: reader.searchNearestVectors(fieldName, target, k, acceptDocs, visitedLimit);
}
+ @Override
+ public TopDocs searchNearestVectors(
+ String fieldName, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+ throws IOException {
+ ensureOpen();
+ LeafReader reader = fieldToReader.get(fieldName);
+ return reader == null
+ ? null
+ : reader.searchNearestVectors(fieldName, target, k, acceptDocs, visitedLimit);
+ }
+
@Override
public void checkIntegrity() throws IOException {
ensureOpen();
diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java
index 8bdc1497433..912bb54d5cf 100644
--- a/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java
+++ b/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java
@@ -30,6 +30,7 @@ import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
/**
* Wraps arbitrary readers for merging. Note that this can cause slow and memory-intensive merges.
@@ -173,6 +174,12 @@ public final class SlowCodecReaderWrapper {
return reader.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
}
+ @Override
+ public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+ throws IOException {
+ return reader.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
+ }
+
@Override
public void checkIntegrity() {
// We already checkIntegrity the entire reader up front
diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java
index ded3360dee6..1444c97c7d1 100644
--- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java
@@ -476,6 +476,12 @@ public final class SortingCodecReader extends FilterCodecReader {
throw new UnsupportedOperationException();
}
+ @Override
+ public TopDocs search(
+ String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
+ throw new UnsupportedOperationException();
+ }
+
@Override
public void close() throws IOException {
delegate.close();
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
similarity index 85%
copy from lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
copy to lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
index 883e2b416ea..c5060c5c694 100644
--- a/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
@@ -23,7 +23,6 @@ import java.util.Arrays;
import java.util.Comparator;
import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
-import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
@@ -43,41 +42,16 @@ import org.apache.lucene.util.Bits;
* <li>If the kNN search visits too many vectors without completing, stop and run an exact search
* </ul>
*/
-public class KnnVectorQuery extends Query {
+abstract class AbstractKnnVectorQuery extends Query {
private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
- private final String field;
- private final float[] target;
- private final int k;
+ protected final String field;
+ protected final int k;
private final Query filter;
- /**
- * Find the <code>k</code> nearest documents to the target vector according to the vectors in the
- * given field. <code>target</code> vector.
- *
- * @param field a field that has been indexed as a {@link KnnVectorField}.
- * @param target the target of the search
- * @param k the number of documents to find
- * @throws IllegalArgumentException if <code>k</code> is less than 1
- */
- public KnnVectorQuery(String field, float[] target, int k) {
- this(field, target, k, null);
- }
-
- /**
- * Find the <code>k</code> nearest documents to the target vector according to the vectors in the
- * given field. <code>target</code> vector.
- *
- * @param field a field that has been indexed as a {@link KnnVectorField}.
- * @param target the target of the search
- * @param k the number of documents to find
- * @param filter a filter applied before the vector search
- * @throws IllegalArgumentException if <code>k</code> is less than 1
- */
- public KnnVectorQuery(String field, float[] target, int k, Query filter) {
+ public AbstractKnnVectorQuery(String field, int k, Query filter) {
this.field = field;
- this.target = target;
this.k = k;
if (k < 1) {
throw new IllegalArgumentException("k must be at least 1, got: " + k);
@@ -168,12 +142,11 @@ public class KnnVectorQuery extends Query {
}
}
- private TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
- throws IOException {
- TopDocs results =
- context.reader().searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
- return results != null ? results : NO_RESULTS;
- }
+ protected abstract TopDocs approximateSearch(
+ LeafReaderContext context, Bits acceptDocs, int visitedLimit) throws IOException;
+
+ abstract VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi)
+ throws IOException;
// We allow this to be overridden so that tests can check what search strategy is used
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
@@ -184,7 +157,7 @@ public class KnnVectorQuery extends Query {
return NO_RESULTS;
}
- VectorScorer vectorScorer = VectorScorer.create(context, fi, target);
+ VectorScorer vectorScorer = createVectorScorer(context, fi);
HitQueue queue = new HitQueue(k, true);
ScoreDoc topDoc = queue.top();
int doc;
@@ -245,11 +218,6 @@ public class KnnVectorQuery extends Query {
return starts;
}
- @Override
- public String toString(String field) {
- return getClass().getSimpleName() + ":" + this.field + "[" + target[0] + ",...][" + k + "]";
- }
-
@Override
public void visit(QueryVisitor visitor) {
if (visitor.acceptField(field)) {
@@ -258,19 +226,16 @@ public class KnnVectorQuery extends Query {
}
@Override
- public boolean equals(Object obj) {
- if (sameClassAs(obj) == false) {
- return false;
- }
- return ((KnnVectorQuery) obj).k == k
- && ((KnnVectorQuery) obj).field.equals(field)
- && Arrays.equals(((KnnVectorQuery) obj).target, target)
- && Objects.equals(filter, ((KnnVectorQuery) obj).filter);
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ AbstractKnnVectorQuery that = (AbstractKnnVectorQuery) o;
+ return k == that.k && Objects.equals(field, that.field) && Objects.equals(filter, that.filter);
}
@Override
public int hashCode() {
- return Objects.hash(classHash(), field, k, Arrays.hashCode(target), filter);
+ return Objects.hash(field, k, filter);
}
/** Caches the results of a KnnVector search: a list of docs and their scores */
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java
new file mode 100644
index 00000000000..0a36be780ae
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import java.io.IOException;
+import java.util.Objects;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
+
+/**
+ * Uses {@link KnnVectorsReader#search(String, BytesRef, int, Bits, int)} to perform nearest
+ * neighbour search over byte-encoded vectors.
+ *
+ * <p>This query also allows for performing a kNN search subject to a filter. In this case, it first
+ * executes the filter for each leaf, then chooses a strategy dynamically:
+ *
+ * <ul>
+ * <li>If the filter cost is less than k, just execute an exact search
+ * <li>Otherwise run a kNN search subject to the filter
+ * <li>If the kNN search visits too many vectors without completing, stop and run an exact search
+ * </ul>
+ */
+public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
+
+ private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
+
+ private final BytesRef target;
+
+ /**
+ * Find the <code>k</code> nearest documents to the target vector according to the vectors in the
+ * given field, using the byte vector <code>target</code> as the query.
+ *
+ * @param field a field indexed as a {@link KnnVectorField} with {@code VectorEncoding.BYTE}.
+ * @param target the byte vector to search for
+ * @param k the number of documents to find
+ * @throws IllegalArgumentException if <code>k</code> is less than 1
+ */
+ public KnnByteVectorQuery(String field, byte[] target, int k) {
+ this(field, target, k, null);
+ }
+
+ /**
+ * Find the <code>k</code> nearest documents to the target vector according to the vectors in the
+ * given field, using the byte vector <code>target</code> as the query.
+ *
+ * @param field a field indexed as a {@link KnnVectorField} with {@code VectorEncoding.BYTE}.
+ * @param target the byte vector to search for
+ * @param k the number of documents to find
+ * @param filter a filter applied before the vector search, or {@code null} for no filter
+ * @throws IllegalArgumentException if <code>k</code> is less than 1
+ */
+ public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) {
+ super(field, k, filter);
+ this.target = new BytesRef(target);
+ }
+
+ @Override
+ protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
+ throws IOException {
+ TopDocs results =
+ context.reader().searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
+ return results != null ? results : NO_RESULTS; // a null reader result means no hits
+ }
+
+ @Override
+ VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException {
+ if (fi.getVectorEncoding() != VectorEncoding.BYTE) {
+ return null; // NOTE(review): non-BYTE field; confirm callers handle a null scorer
+ }
+ return VectorScorer.create(context, fi, target);
+ }
+
+ @Override
+ public String toString(String field) {
+ return getClass().getSimpleName()
+ + ":"
+ + this.field
+ + "["
+ + target.bytes[target.offset] // NOTE(review): throws if target is empty
+ + ",...]["
+ + k
+ + "]";
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (super.equals(o) == false) return false; // exact-class check; cast below is safe
+ KnnByteVectorQuery that = (KnnByteVectorQuery) o;
+ return Objects.equals(target, that.target);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(super.hashCode(), target);
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
index 883e2b416ea..5ed250be0b3 100644
--- a/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
@@ -16,23 +16,18 @@
*/
package org.apache.lucene.search;
-import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
-
import java.io.IOException;
import java.util.Arrays;
-import java.util.Comparator;
-import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.FieldInfo;
-import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
-import org.apache.lucene.util.BitSet;
-import org.apache.lucene.util.BitSetIterator;
+import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.util.Bits;
/**
- * Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
+ * Uses {@link KnnVectorsReader#search(String, float[], int, Bits, int)} to perform nearest
+ * neighbour search.
*
* <p>This query also allows for performing a kNN search subject to a filter. In this case, it first
* executes the filter for each leaf, then chooses a strategy dynamically:
@@ -43,14 +38,11 @@ import org.apache.lucene.util.Bits;
* <li>If the kNN search visits too many vectors without completing, stop and run an exact search
* </ul>
*/
-public class KnnVectorQuery extends Query {
+public class KnnVectorQuery extends AbstractKnnVectorQuery {
private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
- private final String field;
private final float[] target;
- private final int k;
- private final Query filter;
/**
* Find the <code>k</code> nearest documents to the target vector according to the vectors in the
@@ -76,173 +68,24 @@ public class KnnVectorQuery extends Query {
* @throws IllegalArgumentException if <code>k</code> is less than 1
*/
public KnnVectorQuery(String field, float[] target, int k, Query filter) {
- this.field = field;
+ super(field, k, filter);
this.target = target;
- this.k = k;
- if (k < 1) {
- throw new IllegalArgumentException("k must be at least 1, got: " + k);
- }
- this.filter = filter;
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
- TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
-
- Weight filterWeight = null;
- if (filter != null) {
- IndexSearcher indexSearcher = new IndexSearcher(reader);
- BooleanQuery booleanQuery =
- new BooleanQuery.Builder()
- .add(filter, BooleanClause.Occur.FILTER)
- .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
- .build();
- Query rewritten = indexSearcher.rewrite(booleanQuery);
- filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
- }
-
- for (LeafReaderContext ctx : reader.leaves()) {
- TopDocs results = searchLeaf(ctx, filterWeight);
- if (ctx.docBase > 0) {
- for (ScoreDoc scoreDoc : results.scoreDocs) {
- scoreDoc.doc += ctx.docBase;
- }
- }
- perLeafResults[ctx.ord] = results;
- }
- // Merge sort the results
- TopDocs topK = TopDocs.merge(k, perLeafResults);
- if (topK.scoreDocs.length == 0) {
- return new MatchNoDocsQuery();
- }
- return createRewrittenQuery(reader, topK);
- }
-
- private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
- Bits liveDocs = ctx.reader().getLiveDocs();
- int maxDoc = ctx.reader().maxDoc();
-
- if (filterWeight == null) {
- return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE);
- }
-
- Scorer scorer = filterWeight.scorer(ctx);
- if (scorer == null) {
- return NO_RESULTS;
- }
-
- BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);
- int cost = acceptDocs.cardinality();
-
- if (cost <= k) {
- // If there are <= k possible matches, short-circuit and perform exact search, since HNSW
- // must always visit at least k documents
- return exactSearch(ctx, new BitSetIterator(acceptDocs, cost));
- }
-
- // Perform the approximate kNN search
- TopDocs results = approximateSearch(ctx, acceptDocs, cost);
- if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
- return results;
- } else {
- // We stopped the kNN search because it visited too many nodes, so fall back to exact search
- return exactSearch(ctx, new BitSetIterator(acceptDocs, cost));
- }
- }
-
- private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc)
- throws IOException {
- if (liveDocs == null && iterator instanceof BitSetIterator) {
- // If we already have a BitSet and no deletions, reuse the BitSet
- return ((BitSetIterator) iterator).getBitSet();
- } else {
- // Create a new BitSet from matching and live docs
- FilteredDocIdSetIterator filterIterator =
- new FilteredDocIdSetIterator(iterator) {
- @Override
- protected boolean match(int doc) {
- return liveDocs == null || liveDocs.get(doc);
- }
- };
- return BitSet.of(filterIterator, maxDoc);
- }
- }
-
- private TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
+ protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
throws IOException {
TopDocs results =
context.reader().searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
return results != null ? results : NO_RESULTS;
}
- // We allow this to be overridden so that tests can check what search strategy is used
- protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
- throws IOException {
- FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
- if (fi == null || fi.getVectorDimension() == 0) {
- // The field does not exist or does not index vectors
- return NO_RESULTS;
- }
-
- VectorScorer vectorScorer = VectorScorer.create(context, fi, target);
- HitQueue queue = new HitQueue(k, true);
- ScoreDoc topDoc = queue.top();
- int doc;
- while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
- boolean advanced = vectorScorer.advanceExact(doc);
- assert advanced;
-
- float score = vectorScorer.score();
- if (score > topDoc.score) {
- topDoc.score = score;
- topDoc.doc = doc;
- topDoc = queue.updateTop();
- }
- }
-
- // Remove any remaining sentinel values
- while (queue.size() > 0 && queue.top().score < 0) {
- queue.pop();
- }
-
- ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
- for (int i = topScoreDocs.length - 1; i >= 0; i--) {
- topScoreDocs[i] = queue.pop();
- }
-
- TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
- return new TopDocs(totalHits, topScoreDocs);
- }
-
- private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
- int len = topK.scoreDocs.length;
- Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
- int[] docs = new int[len];
- float[] scores = new float[len];
- for (int i = 0; i < len; i++) {
- docs[i] = topK.scoreDocs[i].doc;
- scores[i] = topK.scoreDocs[i].score;
- }
- int[] segmentStarts = findSegmentStarts(reader, docs);
- return new DocAndScoreQuery(k, docs, scores, segmentStarts, reader.getContext().id());
- }
-
- private int[] findSegmentStarts(IndexReader reader, int[] docs) {
- int[] starts = new int[reader.leaves().size() + 1];
- starts[starts.length - 1] = docs.length;
- if (starts.length == 2) {
- return starts;
- }
- int resultIndex = 0;
- for (int i = 1; i < starts.length - 1; i++) {
- int upper = reader.leaves().get(i).docBase;
- resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
- if (resultIndex < 0) {
- resultIndex = -1 - resultIndex;
- }
- starts[i] = resultIndex;
+ @Override
+ VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException {
+ if (fi.getVectorEncoding() != VectorEncoding.FLOAT32) {
+ return null;
}
- return starts;
+ return VectorScorer.create(context, fi, target);
}
@Override
@@ -251,195 +94,17 @@ public class KnnVectorQuery extends Query {
}
@Override
- public void visit(QueryVisitor visitor) {
- if (visitor.acceptField(field)) {
- visitor.visitLeaf(this);
- }
- }
-
- @Override
- public boolean equals(Object obj) {
- if (sameClassAs(obj) == false) {
- return false;
- }
- return ((KnnVectorQuery) obj).k == k
- && ((KnnVectorQuery) obj).field.equals(field)
- && Arrays.equals(((KnnVectorQuery) obj).target, target)
- && Objects.equals(filter, ((KnnVectorQuery) obj).filter);
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (super.equals(o) == false) return false;
+ KnnVectorQuery that = (KnnVectorQuery) o;
+ return Arrays.equals(target, that.target);
}
@Override
public int hashCode() {
- return Objects.hash(classHash(), field, k, Arrays.hashCode(target), filter);
- }
-
- /** Caches the results of a KnnVector search: a list of docs and their scores */
- static class DocAndScoreQuery extends Query {
-
- private final int k;
- private final int[] docs;
- private final float[] scores;
- private final int[] segmentStarts;
- private final Object contextIdentity;
-
- /**
- * Constructor
- *
- * @param k the number of documents requested
- * @param docs the global docids of documents that match, in ascending order
- * @param scores the scores of the matching documents
- * @param segmentStarts the indexes in docs and scores corresponding to the first matching
- * document in each segment. If a segment has no matching documents, it should be assigned
- * the index of the next segment that does. There should be a final entry that is always
- * docs.length-1.
- * @param contextIdentity an object identifying the reader context that was used to build this
- * query
- */
- DocAndScoreQuery(
- int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
- this.k = k;
- this.docs = docs;
- this.scores = scores;
- this.segmentStarts = segmentStarts;
- this.contextIdentity = contextIdentity;
- }
-
- @Override
- public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
- throws IOException {
- if (searcher.getIndexReader().getContext().id() != contextIdentity) {
- throw new IllegalStateException("This DocAndScore query was created by a different reader");
- }
- return new Weight(this) {
- @Override
- public Explanation explain(LeafReaderContext context, int doc) {
- int found = Arrays.binarySearch(docs, doc + context.docBase);
- if (found < 0) {
- return Explanation.noMatch("not in top " + k);
- }
- return Explanation.match(scores[found] * boost, "within top " + k);
- }
-
- @Override
- public Scorer scorer(LeafReaderContext context) {
-
- return new Scorer(this) {
- final int lower = segmentStarts[context.ord];
- final int upper = segmentStarts[context.ord + 1];
- int upTo = -1;
-
- @Override
- public DocIdSetIterator iterator() {
- return new DocIdSetIterator() {
- @Override
- public int docID() {
- return docIdNoShadow();
- }
-
- @Override
- public int nextDoc() {
- if (upTo == -1) {
- upTo = lower;
- } else {
- ++upTo;
- }
- return docIdNoShadow();
- }
-
- @Override
- public int advance(int target) throws IOException {
- return slowAdvance(target);
- }
-
- @Override
- public long cost() {
- return upper - lower;
- }
- };
- }
-
- @Override
- public float getMaxScore(int docId) {
- docId += context.docBase;
- float maxScore = 0;
- for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docId; idx++) {
- maxScore = Math.max(maxScore, scores[idx]);
- }
- return maxScore * boost;
- }
-
- @Override
- public float score() {
- return scores[upTo] * boost;
- }
-
- @Override
- public int advanceShallow(int docid) {
- int start = Math.max(upTo, lower);
- int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase);
- if (docidIndex < 0) {
- docidIndex = -1 - docidIndex;
- }
- if (docidIndex >= upper) {
- return NO_MORE_DOCS;
- }
- return docs[docidIndex];
- }
-
- /**
- * move the implementation of docID() into a differently-named method so we can call it
- * from DocIDSetIterator.docID() even though this class is anonymous
- *
- * @return the current docid
- */
- private int docIdNoShadow() {
- if (upTo == -1) {
- return -1;
- }
- if (upTo >= upper) {
- return NO_MORE_DOCS;
- }
- return docs[upTo] - context.docBase;
- }
-
- @Override
- public int docID() {
- return docIdNoShadow();
- }
- };
- }
-
- @Override
- public boolean isCacheable(LeafReaderContext ctx) {
- return true;
- }
- };
- }
-
- @Override
- public String toString(String field) {
- return "DocAndScore[" + k + "]";
- }
-
- @Override
- public void visit(QueryVisitor visitor) {
- visitor.visitLeaf(this);
- }
-
- @Override
- public boolean equals(Object obj) {
- if (sameClassAs(obj) == false) {
- return false;
- }
- return contextIdentity == ((DocAndScoreQuery) obj).contextIdentity
- && Arrays.equals(docs, ((DocAndScoreQuery) obj).docs)
- && Arrays.equals(scores, ((DocAndScoreQuery) obj).scores);
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(
- classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores));
- }
+ int result = super.hashCode();
+ result = 31 * result + Arrays.hashCode(target);
+ return result;
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/search/VectorScorer.java b/lucene/core/src/java/org/apache/lucene/search/VectorScorer.java
index 6a8850c69ec..eadcdf536b6 100644
--- a/lucene/core/src/java/org/apache/lucene/search/VectorScorer.java
+++ b/lucene/core/src/java/org/apache/lucene/search/VectorScorer.java
@@ -22,7 +22,6 @@ import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.BytesRef;
-import org.apache.lucene.util.VectorUtil;
/**
* Computes the similarity score between a given query vector and different document vectors. This
@@ -40,17 +39,18 @@ abstract class VectorScorer {
* @param fi the FieldInfo for the field containing document vectors
* @param query the query vector to compute the similarity for
*/
- static VectorScorer create(LeafReaderContext context, FieldInfo fi, float[] query)
+ static FloatVectorScorer create(LeafReaderContext context, FieldInfo fi, float[] query)
+ throws IOException {
+ VectorValues values = context.reader().getVectorValues(fi.name);
+ final VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction();
+ return new FloatVectorScorer(values, query, similarity);
+ }
+
+ static ByteVectorScorer create(LeafReaderContext context, FieldInfo fi, BytesRef query)
throws IOException {
VectorValues values = context.reader().getVectorValues(fi.name);
VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction();
- switch (fi.getVectorEncoding()) {
- case BYTE:
- return new ByteVectorScorer(values, query, similarity);
- default:
- case FLOAT32:
- return new FloatVectorScorer(values, query, similarity);
- }
+ return new ByteVectorScorer(values, query, similarity);
}
VectorScorer(VectorValues values, VectorSimilarityFunction similarity) {
@@ -77,9 +77,9 @@ abstract class VectorScorer {
private final BytesRef query;
protected ByteVectorScorer(
- VectorValues values, float[] query, VectorSimilarityFunction similarity) {
+ VectorValues values, BytesRef query, VectorSimilarityFunction similarity) {
super(values, similarity);
- this.query = VectorUtil.toBytesRef(query);
+ this.query = query;
}
@Override
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
index db71d19f458..a2650f6e5d9 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
@@ -18,7 +18,6 @@
package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
-import static org.apache.lucene.util.VectorUtil.toBytesRef;
import java.io.IOException;
import org.apache.lucene.index.VectorEncoding;
@@ -96,17 +95,6 @@ public class HnswGraphSearcher<T> {
+ " differs from field dimension: "
+ vectors.dimension());
}
- if (vectorEncoding == VectorEncoding.BYTE) {
- return search(
- toBytesRef(query),
- topK,
- vectors,
- vectorEncoding,
- similarityFunction,
- graph,
- acceptOrds,
- visitedLimit);
- }
HnswGraphSearcher<float[]> graphSearcher =
new HnswGraphSearcher<>(
vectorEncoding,
@@ -132,7 +120,21 @@ public class HnswGraphSearcher<T> {
return results;
}
- private static NeighborQueue search(
+ /**
+ * Searches HNSW graph for the nearest neighbors of a query vector.
+ *
+ * @param query search query vector
+ * @param topK the number of nodes to be returned
+ * @param vectors the vector values
+ * @param similarityFunction the similarity function to compare vectors
+ * @param graph the graph values. May represent the entire graph, or a level in a hierarchical
+ * graph.
+ * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
+ * {@code null} if they are all allowed to match.
+ * @param visitedLimit the maximum number of nodes that the search is allowed to visit
+ * @return a priority queue holding the closest neighbors found
+ */
+ public static NeighborQueue search(
BytesRef query,
int topK,
RandomAccessVectorValues vectors,
@@ -142,6 +144,13 @@ public class HnswGraphSearcher<T> {
Bits acceptOrds,
int visitedLimit)
throws IOException {
+ if (query.length != vectors.dimension()) {
+ throw new IllegalArgumentException(
+ "vector query dimension: "
+ + query.length
+ + " differs from field dimension: "
+ + vectors.dimension());
+ }
HnswGraphSearcher<BytesRef> graphSearcher =
new HnswGraphSearcher<>(
vectorEncoding,
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java b/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java
index 62d62d538a0..279456a442e 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java
@@ -33,6 +33,7 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.NamedThreadFactory;
import org.apache.lucene.util.Version;
@@ -122,6 +123,12 @@ public class TestSegmentToThreadMapping extends LuceneTestCase {
return null;
}
+ @Override
+ public TopDocs searchNearestVectors(
+ String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
+ return null;
+ }
+
@Override
protected void doClose() {}
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
similarity index 73%
copy from lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
copy to lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
index ba8fb8c91e3..2363f207757 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
@@ -18,10 +18,7 @@ package org.apache.lucene.search;
import static com.carrotsearch.randomizedtesting.RandomizedTest.frequently;
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
-import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
-import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
-import static org.apache.lucene.util.TestVectorUtil.randomVector;
import java.io.IOException;
import java.util.HashSet;
@@ -29,7 +26,6 @@ import java.util.Set;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.IntPoint;
-import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
@@ -43,6 +39,7 @@ import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
@@ -50,51 +47,63 @@ import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
-import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
-import org.apache.lucene.util.VectorUtil;
-/** TestKnnVectorQuery tests KnnVectorQuery. */
-public class TestKnnVectorQuery extends LuceneTestCase {
+/** Test cases for KnnVectorQuery objects. */
+abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
+
+ abstract AbstractKnnVectorQuery getKnnVectorQuery(
+ String field, float[] query, int k, Query queryFilter);
+
+ abstract AbstractKnnVectorQuery getThrowingKnnVectorQuery(
+ String field, float[] query, int k, Query queryFilter);
+
+ AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k) {
+ return getKnnVectorQuery(field, query, k, null);
+ }
+
+ abstract float[] randomVector(int dim);
+
+ abstract VectorEncoding vectorEncoding();
+
+ abstract Field getKnnVectorField(
+ String name, float[] vector, VectorSimilarityFunction similarityFunction);
+
+ abstract Field getKnnVectorField(String name, float[] vector);
public void testEquals() {
- KnnVectorQuery q1 = new KnnVectorQuery("f1", new float[] {0, 1}, 10);
+ AbstractKnnVectorQuery q1 = getKnnVectorQuery("f1", new float[] {0, 1}, 10);
Query filter1 = new TermQuery(new Term("id", "id1"));
- KnnVectorQuery q2 = new KnnVectorQuery("f1", new float[] {0, 1}, 10, filter1);
+ AbstractKnnVectorQuery q2 = getKnnVectorQuery("f1", new float[] {0, 1}, 10, filter1);
assertNotEquals(q2, q1);
assertNotEquals(q1, q2);
- assertEquals(q2, new KnnVectorQuery("f1", new float[] {0, 1}, 10, filter1));
+ assertEquals(q2, getKnnVectorQuery("f1", new float[] {0, 1}, 10, filter1));
Query filter2 = new TermQuery(new Term("id", "id2"));
- assertNotEquals(q2, new KnnVectorQuery("f1", new float[] {0, 1}, 10, filter2));
+ assertNotEquals(q2, getKnnVectorQuery("f1", new float[] {0, 1}, 10, filter2));
- assertEquals(q1, new KnnVectorQuery("f1", new float[] {0, 1}, 10));
+ assertEquals(q1, getKnnVectorQuery("f1", new float[] {0, 1}, 10));
assertNotEquals(null, q1);
assertNotEquals(q1, new TermQuery(new Term("f1", "x")));
- assertNotEquals(q1, new KnnVectorQuery("f2", new float[] {0, 1}, 10));
- assertNotEquals(q1, new KnnVectorQuery("f1", new float[] {1, 1}, 10));
- assertNotEquals(q1, new KnnVectorQuery("f1", new float[] {0, 1}, 2));
- assertNotEquals(q1, new KnnVectorQuery("f1", new float[] {0}, 10));
- }
-
- public void testToString() {
- KnnVectorQuery q1 = new KnnVectorQuery("f1", new float[] {0, 1}, 10);
- assertEquals("KnnVectorQuery:f1[0.0,...][10]", q1.toString("ignored"));
+ assertNotEquals(q1, getKnnVectorQuery("f2", new float[] {0, 1}, 10));
+ assertNotEquals(q1, getKnnVectorQuery("f1", new float[] {1, 1}, 10));
+ assertNotEquals(q1, getKnnVectorQuery("f1", new float[] {0, 1}, 2));
+ assertNotEquals(q1, getKnnVectorQuery("f1", new float[] {0}, 10));
}
/**
- * Tests if a KnnVectorQuery is rewritten to a MatchNoDocsQuery when there are no documents to
- * match.
+ * Tests that an AbstractKnnVectorQuery is rewritten to a MatchNoDocsQuery when there are no
+ * documents to match.
*/
public void testEmptyIndex() throws IOException {
try (Directory indexStore = getIndexStore("field");
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
- KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {1, 2}, 10);
+ AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {1, 2}, 10);
assertMatches(searcher, kvq, 0);
Query q = searcher.rewrite(kvq);
assertTrue(q instanceof MatchNoDocsQuery);
@@ -102,14 +111,15 @@ public class TestKnnVectorQuery extends LuceneTestCase {
}
/**
- * Tests that a KnnVectorQuery whose topK >= numDocs returns all the documents in score order
+ * Tests that an AbstractKnnVectorQuery whose topK >= numDocs returns all the documents in
+ * score order.
*/
public void testFindAll() throws IOException {
try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
- KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {0, 0}, 10);
+ AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0, 0}, 10);
assertMatches(searcher, kvq, 3);
ScoreDoc[] scoreDocs = searcher.search(kvq, 3).scoreDocs;
assertIdMatches(reader, "id2", scoreDocs[0]);
@@ -124,7 +134,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
- Query vectorQuery = new KnnVectorQuery("field", new float[] {0, 0}, 10);
+ Query vectorQuery = getKnnVectorQuery("field", new float[] {0, 0}, 10);
ScoreDoc[] scoreDocs = searcher.search(vectorQuery, 3).scoreDocs;
Query boostQuery = new BoostQuery(vectorQuery, 3.0f);
@@ -141,14 +151,14 @@ public class TestKnnVectorQuery extends LuceneTestCase {
}
}
- /** Tests that a KnnVectorQuery applies the filter query */
+ /** Tests that an AbstractKnnVectorQuery applies the filter query. */
public void testSimpleFilter() throws IOException {
try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
Query filter = new TermQuery(new Term("id", "id2"));
- Query kvq = new KnnVectorQuery("field", new float[] {0, 0}, 10, filter);
+ Query kvq = getKnnVectorQuery("field", new float[] {0, 0}, 10, filter);
TopDocs topDocs = searcher.search(kvq, 3);
assertEquals(1, topDocs.totalHits.value);
assertIdMatches(reader, "id2", topDocs.scoreDocs[0]);
@@ -162,7 +172,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
IndexSearcher searcher = newSearcher(reader);
Query filter = new TermQuery(new Term("other", "value"));
- Query kvq = new KnnVectorQuery("field", new float[] {0, 0}, 10, filter);
+ Query kvq = getKnnVectorQuery("field", new float[] {0, 0}, 10, filter);
TopDocs topDocs = searcher.search(kvq, 3);
assertEquals(0, topDocs.totalHits.value);
}
@@ -174,7 +184,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
- KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {0}, 10);
+ AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 10);
IllegalArgumentException e =
expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage());
@@ -187,22 +197,21 @@ public class TestKnnVectorQuery extends LuceneTestCase {
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
- assertMatches(searcher, new KnnVectorQuery("xyzzy", new float[] {0}, 10), 0);
- assertMatches(searcher, new KnnVectorQuery("id", new float[] {0}, 10), 0);
+ assertMatches(searcher, getKnnVectorQuery("xyzzy", new float[] {0}, 10), 0);
+ assertMatches(searcher, getKnnVectorQuery("id", new float[] {0}, 10), 0);
}
}
/** Test bad parameters */
public void testIllegalArguments() throws IOException {
- expectThrows(
- IllegalArgumentException.class, () -> new KnnVectorQuery("xx", new float[] {1}, 0));
+ expectThrows(IllegalArgumentException.class, () -> getKnnVectorQuery("xx", new float[] {1}, 0));
}
public void testDifferentReader() throws IOException {
try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
- KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
+ AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
Query dasq = query.rewrite(reader);
IndexSearcher leafSearcher = newSearcher(reader.leaves().get(0).reader());
expectThrows(
@@ -216,13 +225,13 @@ public class TestKnnVectorQuery extends LuceneTestCase {
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
for (int j = 0; j < 5; j++) {
Document doc = new Document();
- doc.add(new KnnVectorField("field", new float[] {j, j}));
+ doc.add(getKnnVectorField("field", new float[] {j, j}));
w.addDocument(doc);
}
}
try (IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = new IndexSearcher(reader);
- KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
+ AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
Query dasq = query.rewrite(reader);
Scorer scorer =
dasq.createWeight(searcher, ScoreMode.COMPLETE, 1).scorer(reader.leaves().get(0));
@@ -249,7 +258,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
try (Directory d = getStableIndexStore("field", vectors);
IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = new IndexSearcher(reader);
- KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
+ AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
Query rewritten = query.rewrite(reader);
Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
Scorer scorer = weight.scorer(reader.leaves().get(0));
@@ -276,76 +285,19 @@ public class TestKnnVectorQuery extends LuceneTestCase {
}
}
- public void testScoreDotProduct() throws IOException {
- try (Directory d = newDirectory()) {
- try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
- for (int j = 1; j <= 5; j++) {
- Document doc = new Document();
- doc.add(
- new KnnVectorField(
- "field", VectorUtil.l2normalize(new float[] {j, j * j}), DOT_PRODUCT));
- w.addDocument(doc);
- }
- }
- try (IndexReader reader = DirectoryReader.open(d)) {
- assertEquals(1, reader.leaves().size());
- IndexSearcher searcher = new IndexSearcher(reader);
- KnnVectorQuery query =
- new KnnVectorQuery("field", VectorUtil.l2normalize(new float[] {2, 3}), 3);
- Query rewritten = query.rewrite(reader);
- Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
- Scorer scorer = weight.scorer(reader.leaves().get(0));
-
- // prior to advancing, score is undefined
- assertEquals(-1, scorer.docID());
- expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
-
- // test getMaxScore
- assertEquals(0, scorer.getMaxScore(-1), 0);
- /* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
- * normalized by (1 + x) /2.
- */
- float maxAtZero =
- (float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2);
- assertEquals(maxAtZero, scorer.getMaxScore(0), 0.001);
-
- /* max at 2 is actually the score for doc 1 which is the highest (since doc 1 vector (2, 4)
- * is the closest to (2, 3)). This is ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
- * normalized by (1 + x) /2
- */
- float expected =
- (float) ((1 + (2 * 2 + 3 * 4) / Math.sqrt((2 * 2 + 3 * 3) * (2 * 2 + 4 * 4))) / 2);
- assertEquals(expected, scorer.getMaxScore(2), 0);
- assertEquals(expected, scorer.getMaxScore(Integer.MAX_VALUE), 0);
-
- DocIdSetIterator it = scorer.iterator();
- assertEquals(3, it.cost());
- assertEquals(0, it.nextDoc());
- // doc 0 has (1, 1)
- assertEquals(maxAtZero, scorer.score(), 0.0001);
- assertEquals(1, it.advance(1));
- assertEquals(expected, scorer.score(), 0);
- assertEquals(2, it.nextDoc());
- // since topK was 3
- assertEquals(NO_MORE_DOCS, it.advance(4));
- expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
- }
- }
- }
-
public void testScoreCosine() throws IOException {
try (Directory d = newDirectory()) {
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
for (int j = 1; j <= 5; j++) {
Document doc = new Document();
- doc.add(new KnnVectorField("field", new float[] {j, j * j}, COSINE));
+ doc.add(getKnnVectorField("field", new float[] {j, j * j}, COSINE));
w.addDocument(doc);
}
}
try (IndexReader reader = DirectoryReader.open(d)) {
assertEquals(1, reader.leaves().size());
IndexSearcher searcher = new IndexSearcher(reader);
- KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
+ AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
Query rewritten = query.rewrite(reader);
Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
Scorer scorer = weight.scorer(reader.leaves().get(0));
@@ -387,47 +339,18 @@ public class TestKnnVectorQuery extends LuceneTestCase {
}
}
- public void testScoreNegativeDotProduct() throws IOException {
- try (Directory d = newDirectory()) {
- try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
- Document doc = new Document();
- doc.add(new KnnVectorField("field", new float[] {-1, 0}, DOT_PRODUCT));
- w.addDocument(doc);
- doc = new Document();
- doc.add(new KnnVectorField("field", new float[] {1, 0}, DOT_PRODUCT));
- w.addDocument(doc);
- }
- try (IndexReader reader = DirectoryReader.open(d)) {
- assertEquals(1, reader.leaves().size());
- IndexSearcher searcher = new IndexSearcher(reader);
- KnnVectorQuery query = new KnnVectorQuery("field", new float[] {1, 0}, 2);
- Query rewritten = query.rewrite(reader);
- Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
- Scorer scorer = weight.scorer(reader.leaves().get(0));
-
- // scores are normalized to lie in [0, 1]
- DocIdSetIterator it = scorer.iterator();
- assertEquals(2, it.cost());
- assertEquals(0, it.nextDoc());
- assertEquals(0, scorer.score(), 0);
- assertEquals(1, it.advance(1));
- assertEquals(1, scorer.score(), 0);
- }
- }
- }
-
public void testExplain() throws IOException {
try (Directory d = newDirectory()) {
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
for (int j = 0; j < 5; j++) {
Document doc = new Document();
- doc.add(new KnnVectorField("field", new float[] {j, j}));
+ doc.add(getKnnVectorField("field", new float[] {j, j}));
w.addDocument(doc);
}
}
try (IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = new IndexSearcher(reader);
- KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
+ AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
Explanation matched = searcher.explain(query, 2);
assertTrue(matched.isMatch());
assertEquals(1 / 2f, matched.getValue());
@@ -448,14 +371,14 @@ public class TestKnnVectorQuery extends LuceneTestCase {
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
for (int j = 0; j < 5; j++) {
Document doc = new Document();
- doc.add(new KnnVectorField("field", new float[] {j, j}));
+ doc.add(getKnnVectorField("field", new float[] {j, j}));
w.addDocument(doc);
w.commit();
}
}
try (IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = new IndexSearcher(reader);
- KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
+ AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
Explanation matched = searcher.explain(query, 2);
assertTrue(matched.isMatch());
assertEquals(1 / 2f, matched.getValue());
@@ -483,7 +406,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
for (int i = 0; i < 5; i++) {
for (int j = 0; j < 5; j++) {
Document doc = new Document();
- doc.add(new KnnVectorField("field", new float[] {r, r}));
+ doc.add(getKnnVectorField("field", new float[] {r, r}));
doc.add(new StringField("id", "id" + r, Field.Store.YES));
w.addDocument(doc);
++r;
@@ -493,13 +416,13 @@ public class TestKnnVectorQuery extends LuceneTestCase {
}
try (IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = newSearcher(reader);
- TopDocs results = searcher.search(new KnnVectorQuery("field", new float[] {0, 0}, 8), 10);
+ TopDocs results = searcher.search(getKnnVectorQuery("field", new float[] {0, 0}, 8), 10);
assertEquals(8, results.scoreDocs.length);
assertIdMatches(reader, "id0", results.scoreDocs[0]);
assertIdMatches(reader, "id7", results.scoreDocs[7]);
// test some results in the middle of the sequence - also tests docid tiebreaking
- results = searcher.search(new KnnVectorQuery("field", new float[] {10, 10}, 8), 10);
+ results = searcher.search(getKnnVectorQuery("field", new float[] {10, 10}, 8), 10);
assertEquals(8, results.scoreDocs.length);
assertIdMatches(reader, "id10", results.scoreDocs[0]);
assertIdMatches(reader, "id6", results.scoreDocs[7]);
@@ -518,7 +441,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
for (int i = 0; i < numDocs; i++) {
Document doc = new Document();
if (everyDocHasAVector || random().nextInt(10) != 2) {
- doc.add(new KnnVectorField("field", randomVector(dimension)));
+ doc.add(getKnnVectorField("field", randomVector(dimension)));
}
w.addDocument(doc);
}
@@ -527,7 +450,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
IndexSearcher searcher = newSearcher(reader);
for (int i = 0; i < numIters; i++) {
int k = random().nextInt(80) + 1;
- KnnVectorQuery query = new KnnVectorQuery("field", randomVector(dimension), k);
+ AbstractKnnVectorQuery query = getKnnVectorQuery("field", randomVector(dimension), k);
int n = random().nextInt(100) + 1;
TopDocs results = searcher.search(query, n);
int expected = Math.min(Math.min(n, k), reader.numDocs());
@@ -554,13 +477,14 @@ public class TestKnnVectorQuery extends LuceneTestCase {
int numIters = atLeast(10);
try (Directory d = newDirectory()) {
// Always use the default kNN format to have predictable behavior around when it hits
- // visitedLimit. This is fine since the test targets KnnVectorQuery logic, not the kNN format
+ // visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the
+ // kNN format
// implementation.
IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec());
RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc);
for (int i = 0; i < numDocs; i++) {
Document doc = new Document();
- doc.add(new KnnVectorField("field", randomVector(dimension)));
+ doc.add(getKnnVectorField("field", randomVector(dimension)));
doc.add(new NumericDocValuesField("tag", i));
doc.add(new IntPoint("tag", i));
w.addDocument(doc);
@@ -577,35 +501,35 @@ public class TestKnnVectorQuery extends LuceneTestCase {
Query filter1 = IntPoint.newRangeQuery("tag", lower, lower + 8);
TopDocs results =
searcher.search(
- new KnnVectorQuery("field", randomVector(dimension), 10, filter1), numDocs);
+ getKnnVectorQuery("field", randomVector(dimension), 10, filter1), numDocs);
assertEquals(9, results.totalHits.value);
assertEquals(results.totalHits.value, results.scoreDocs.length);
expectThrows(
UnsupportedOperationException.class,
() ->
searcher.search(
- new ThrowingKnnVectorQuery("field", randomVector(dimension), 10, filter1),
+ getThrowingKnnVectorQuery("field", randomVector(dimension), 10, filter1),
numDocs));
// Test a restrictive filter and check we use exact search
Query filter2 = IntPoint.newRangeQuery("tag", lower, lower + 6);
results =
searcher.search(
- new KnnVectorQuery("field", randomVector(dimension), 5, filter2), numDocs);
+ getKnnVectorQuery("field", randomVector(dimension), 5, filter2), numDocs);
assertEquals(5, results.totalHits.value);
assertEquals(results.totalHits.value, results.scoreDocs.length);
expectThrows(
UnsupportedOperationException.class,
() ->
searcher.search(
- new ThrowingKnnVectorQuery("field", randomVector(dimension), 5, filter2),
+ getThrowingKnnVectorQuery("field", randomVector(dimension), 5, filter2),
numDocs));
// Test an unrestrictive filter and check we use approximate search
Query filter3 = IntPoint.newRangeQuery("tag", lower, numDocs);
results =
searcher.search(
- new ThrowingKnnVectorQuery("field", randomVector(dimension), 5, filter3),
+ getThrowingKnnVectorQuery("field", randomVector(dimension), 5, filter3),
numDocs,
new Sort(new SortField("tag", SortField.Type.INT)));
assertEquals(5, results.totalHits.value);
@@ -625,7 +549,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
UnsupportedOperationException.class,
() ->
searcher.search(
- new ThrowingKnnVectorQuery("field", randomVector(dimension), 1, filter4),
+ getThrowingKnnVectorQuery("field", randomVector(dimension), 1, filter4),
numDocs));
}
}
@@ -639,14 +563,15 @@ public class TestKnnVectorQuery extends LuceneTestCase {
int dimension = atLeast(5);
try (Directory d = newDirectory()) {
// Always use the default kNN format to have predictable behavior around when it hits
- // visitedLimit. This is fine since the test targets KnnVectorQuery logic, not the kNN format
+ // visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the
+ // kNN format
// implementation.
IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec());
IndexWriter w = new IndexWriter(d, iwc);
float[] vector = randomVector(dimension);
for (int i = 0; i < numDocs; i++) {
Document doc = new Document();
- doc.add(new KnnVectorField("field", vector));
+ doc.add(getKnnVectorField("field", vector));
doc.add(new IntPoint("tag", i));
w.addDocument(doc);
}
@@ -662,14 +587,14 @@ public class TestKnnVectorQuery extends LuceneTestCase {
Query filter1 = IntPoint.newRangeQuery("tag", lower, lower + 6);
TopDocs results =
searcher.search(
- new KnnVectorQuery("field", randomVector(dimension), size, filter1), size);
+ getKnnVectorQuery("field", randomVector(dimension), size, filter1), size);
assertEquals(size, results.scoreDocs.length);
// Test an unrestrictive filter, which usually performs approximate search
Query filter2 = IntPoint.newRangeQuery("tag", lower, numDocs);
results =
searcher.search(
- new KnnVectorQuery("field", randomVector(dimension), size, filter2), size);
+ getKnnVectorQuery("field", randomVector(dimension), size, filter2), size);
assertEquals(size, results.scoreDocs.length);
}
}
@@ -684,7 +609,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
Document d = new Document();
d.add(new StringField("index", String.valueOf(i), Field.Store.YES));
if (frequently()) {
- d.add(new KnnVectorField("vector", randomVector(dim)));
+ d.add(getKnnVectorField("vector", randomVector(dim)));
}
w.addDocument(d);
}
@@ -703,7 +628,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
try (IndexReader reader = DirectoryReader.open(dir)) {
Set<String> allIds = new HashSet<>();
IndexSearcher searcher = new IndexSearcher(reader);
- KnnVectorQuery query = new KnnVectorQuery("vector", randomVector(dim), hits);
+ AbstractKnnVectorQuery query = getKnnVectorQuery("vector", randomVector(dim), hits);
TopDocs topDocs = searcher.search(query, numDocs);
StoredFields storedFields = reader.storedFields();
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
@@ -726,7 +651,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
final int dim = 30;
for (int i = 0; i < numDocs; ++i) {
Document d = new Document();
- d.add(new KnnVectorField("vector", randomVector(dim)));
+ d.add(getKnnVectorField("vector", randomVector(dim)));
w.addDocument(d);
}
w.commit();
@@ -736,7 +661,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
try (IndexReader reader = DirectoryReader.open(dir)) {
IndexSearcher searcher = new IndexSearcher(reader);
- KnnVectorQuery query = new KnnVectorQuery("vector", randomVector(dim), numDocs);
+ AbstractKnnVectorQuery query = getKnnVectorQuery("vector", randomVector(dim), numDocs);
TopDocs topDocs = searcher.search(query, numDocs);
assertEquals(0, topDocs.scoreDocs.length);
}
@@ -756,7 +681,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
for (int i = 0; i < numDocs; ++i) {
Document d = new Document();
d.add(new StringField("index", String.valueOf(i), Field.Store.NO));
- d.add(new KnnVectorField("vector", randomVector(dim)));
+ d.add(getKnnVectorField("vector", randomVector(dim)));
w.addDocument(d);
}
w.commit();
@@ -764,7 +689,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
try (DirectoryReader reader = DirectoryReader.open(dir)) {
DirectoryReader wrappedReader = new NoLiveDocsDirectoryReader(reader);
IndexSearcher searcher = new IndexSearcher(wrappedReader);
- KnnVectorQuery query = new KnnVectorQuery("vector", randomVector(dim), numDocs);
+ AbstractKnnVectorQuery query = getKnnVectorQuery("vector", randomVector(dim), numDocs);
TopDocs topDocs = searcher.search(query, numDocs);
assertEquals(0, topDocs.scoreDocs.length);
}
@@ -772,7 +697,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
}
/**
- * Test that KnnVectorQuery optimizes the case where the filter query is backed by {@link
+ * Test that AbstractKnnVectorQuery optimizes the case where the filter query is backed by {@link
* BitSetIterator}.
*/
public void testBitSetQuery() throws IOException {
@@ -783,7 +708,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
final int dim = 30;
for (int i = 0; i < numDocs; ++i) {
Document d = new Document();
- d.add(new KnnVectorField("vector", randomVector(dim)));
+ d.add(getKnnVectorField("vector", randomVector(dim)));
w.addDocument(d);
}
w.commit();
@@ -796,27 +721,18 @@ public class TestKnnVectorQuery extends LuceneTestCase {
UnsupportedOperationException.class,
() ->
searcher.search(
- new KnnVectorQuery("vector", randomVector(dim), 10, filter), numDocs));
+ getKnnVectorQuery("vector", randomVector(dim), 10, filter), numDocs));
}
}
}
/** Creates a new directory and adds documents with the given vectors as kNN vector fields */
- private Directory getIndexStore(String field, float[]... contents) throws IOException {
+ Directory getIndexStore(String field, float[]... contents) throws IOException {
Directory indexStore = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
- VectorEncoding encoding = randomVectorEncoding();
for (int i = 0; i < contents.length; ++i) {
Document doc = new Document();
- if (encoding == VectorEncoding.BYTE) {
- BytesRef v = new BytesRef(new byte[contents[i].length]);
- for (int j = 0; j < v.length; j++) {
- v.bytes[j] = (byte) contents[i][j];
- }
- doc.add(new KnnVectorField(field, v, EUCLIDEAN));
- } else {
- doc.add(new KnnVectorField(field, contents[i]));
- }
+ doc.add(getKnnVectorField(field, contents[i]));
doc.add(new StringField("id", "id" + i, Field.Store.YES));
writer.addDocument(doc);
}
@@ -837,18 +753,9 @@ public class TestKnnVectorQuery extends LuceneTestCase {
private Directory getStableIndexStore(String field, float[]... contents) throws IOException {
Directory indexStore = newDirectory();
try (IndexWriter writer = new IndexWriter(indexStore, new IndexWriterConfig())) {
- VectorEncoding encoding = randomVectorEncoding();
for (int i = 0; i < contents.length; ++i) {
Document doc = new Document();
- if (encoding == VectorEncoding.BYTE) {
- BytesRef v = new BytesRef(new byte[contents[i].length]);
- for (int j = 0; j < v.length; j++) {
- v.bytes[j] = (byte) contents[i][j];
- }
- doc.add(new KnnVectorField(field, v, EUCLIDEAN));
- } else {
- doc.add(new KnnVectorField(field, contents[i]));
- }
+ doc.add(getKnnVectorField(field, contents[i]));
doc.add(new StringField("id", "id" + i, Field.Store.YES));
writer.addDocument(doc);
}
@@ -868,28 +775,16 @@ public class TestKnnVectorQuery extends LuceneTestCase {
assertEquals(expectedMatches, result.length);
}
- private void assertIdMatches(IndexReader reader, String expectedId, ScoreDoc scoreDoc)
+ void assertIdMatches(IndexReader reader, String expectedId, ScoreDoc scoreDoc)
throws IOException {
String actualId = reader.storedFields().document(scoreDoc.doc).get("id");
assertEquals(expectedId, actualId);
}
/**
- * A version of {@link KnnVectorQuery} that throws an error when an exact search is run. This
- * allows us to check what search strategy is being used.
+ * Wraps a {@link DirectoryReader} so that, per its name, no documents are live; used to check
+ * that a kNN query over such a reader returns no hits.
*/
- private static class ThrowingKnnVectorQuery extends KnnVectorQuery {
-
- public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter) {
- super(field, target, k, filter);
- }
-
- @Override
- protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) {
- throw new UnsupportedOperationException("exact search is not supported");
- }
- }
-
private static class NoLiveDocsDirectoryReader extends FilterDirectoryReader {
private NoLiveDocsDirectoryReader(DirectoryReader in) throws IOException {
@@ -940,7 +835,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
}
}
- private static class ThrowingBitSetQuery extends Query {
+ static class ThrowingBitSetQuery extends Query {
private final FixedBitSet docs;
@@ -989,8 +884,4 @@ public class TestKnnVectorQuery extends LuceneTestCase {
return 31 * classHash() + docs.hashCode();
}
}
-
- private VectorEncoding randomVectorEncoding() {
- return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
- }
}
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java
new file mode 100644
index 00000000000..f5037af9177
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.TestVectorUtil;
+
+/** Runs the shared {@link BaseKnnVectorQueryTestCase} suite against {@link KnnByteVectorQuery}. */
+public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
+  @Override
+  AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) {
+    return new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter);
+  }
+
+  @Override
+  AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) {
+    return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query);
+  }
+
+  @Override
+  float[] randomVector(int dim) {
+    BytesRef bytesRef = TestVectorUtil.randomVectorBytes(dim);
+    float[] v = new float[bytesRef.length];
+    // Valid bytes live in bytes[offset, offset + length), so index relative to the offset.
+    for (int i = 0; i < v.length; i++) {
+      v[i] = bytesRef.bytes[bytesRef.offset + i];
+    }
+    return v;
+  }
+
+  @Override
+  Field getKnnVectorField(
+      String name, float[] vector, VectorSimilarityFunction similarityFunction) {
+    return new KnnVectorField(name, new BytesRef(floatToBytes(vector)), similarityFunction);
+  }
+
+  @Override
+  Field getKnnVectorField(String name, float[] vector) {
+    return new KnnVectorField(
+        name, new BytesRef(floatToBytes(vector)), VectorSimilarityFunction.EUCLIDEAN);
+  }
+
+  /** Narrows each value to a byte, asserting it is a whole number within the byte range. */
+  private static byte[] floatToBytes(float[] query) {
+    byte[] bytes = new byte[query.length];
+    for (int i = 0; i < query.length; i++) {
+      assert query[i] <= Byte.MAX_VALUE && query[i] >= Byte.MIN_VALUE && (query[i] % 1) == 0
+          : "float value cannot be converted to byte; provided: " + query[i];
+      bytes[i] = (byte) query[i];
+    }
+    return bytes;
+  }
+
+  public void testToString() {
+    AbstractKnnVectorQuery q1 = getKnnVectorQuery("f1", new float[] {0, 1}, 10);
+    assertEquals("KnnByteVectorQuery:f1[0,...][10]", q1.toString("ignored"));
+  }
+
+  @Override
+  VectorEncoding vectorEncoding() {
+    return VectorEncoding.BYTE;
+  }
+
+  /**
+   * A {@link KnnByteVectorQuery} whose exact search throws, so tests can tell which strategy ran.
+   */
+  private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery {
+
+    public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) {
+      super(field, target, k, filter);
+    }
+
+    @Override
+    protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) {
+      throw new UnsupportedOperationException("exact search is not supported");
+    }
+  }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
index ba8fb8c91e3..c50b6b864d9 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
@@ -16,336 +16,106 @@
*/
package org.apache.lucene.search;
-import static com.carrotsearch.randomizedtesting.RandomizedTest.frequently;
-import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
-import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
-import static org.apache.lucene.util.TestVectorUtil.randomVector;
import java.io.IOException;
-import java.util.HashSet;
-import java.util.Set;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
-import org.apache.lucene.document.IntPoint;
import org.apache.lucene.document.KnnVectorField;
-import org.apache.lucene.document.NumericDocValuesField;
-import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
-import org.apache.lucene.index.FilterDirectoryReader;
-import org.apache.lucene.index.FilterLeafReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
-import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
-import org.apache.lucene.index.StoredFields;
-import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
-import org.apache.lucene.tests.index.RandomIndexWriter;
-import org.apache.lucene.tests.util.LuceneTestCase;
-import org.apache.lucene.tests.util.TestUtil;
-import org.apache.lucene.util.BitSet;
-import org.apache.lucene.util.BitSetIterator;
-import org.apache.lucene.util.Bits;
-import org.apache.lucene.util.BytesRef;
-import org.apache.lucene.util.FixedBitSet;
+import org.apache.lucene.util.TestVectorUtil;
import org.apache.lucene.util.VectorUtil;
-/** TestKnnVectorQuery tests KnnVectorQuery. */
-public class TestKnnVectorQuery extends LuceneTestCase {
-
- public void testEquals() {
- KnnVectorQuery q1 = new KnnVectorQuery("f1", new float[] {0, 1}, 10);
- Query filter1 = new TermQuery(new Term("id", "id1"));
- KnnVectorQuery q2 = new KnnVectorQuery("f1", new float[] {0, 1}, 10, filter1);
-
- assertNotEquals(q2, q1);
- assertNotEquals(q1, q2);
- assertEquals(q2, new KnnVectorQuery("f1", new float[] {0, 1}, 10, filter1));
-
- Query filter2 = new TermQuery(new Term("id", "id2"));
- assertNotEquals(q2, new KnnVectorQuery("f1", new float[] {0, 1}, 10, filter2));
-
- assertEquals(q1, new KnnVectorQuery("f1", new float[] {0, 1}, 10));
-
- assertNotEquals(null, q1);
-
- assertNotEquals(q1, new TermQuery(new Term("f1", "x")));
-
- assertNotEquals(q1, new KnnVectorQuery("f2", new float[] {0, 1}, 10));
- assertNotEquals(q1, new KnnVectorQuery("f1", new float[] {1, 1}, 10));
- assertNotEquals(q1, new KnnVectorQuery("f1", new float[] {0, 1}, 2));
- assertNotEquals(q1, new KnnVectorQuery("f1", new float[] {0}, 10));
- }
-
- public void testToString() {
- KnnVectorQuery q1 = new KnnVectorQuery("f1", new float[] {0, 1}, 10);
- assertEquals("KnnVectorQuery:f1[0.0,...][10]", q1.toString("ignored"));
- }
-
- /**
- * Tests if a KnnVectorQuery is rewritten to a MatchNoDocsQuery when there are no documents to
- * match.
- */
- public void testEmptyIndex() throws IOException {
- try (Directory indexStore = getIndexStore("field");
- IndexReader reader = DirectoryReader.open(indexStore)) {
- IndexSearcher searcher = newSearcher(reader);
- KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {1, 2}, 10);
- assertMatches(searcher, kvq, 0);
- Query q = searcher.rewrite(kvq);
- assertTrue(q instanceof MatchNoDocsQuery);
- }
- }
-
- /**
- * Tests that a KnnVectorQuery whose topK >= numDocs returns all the documents in score order
- */
- public void testFindAll() throws IOException {
- try (Directory indexStore =
- getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
- IndexReader reader = DirectoryReader.open(indexStore)) {
- IndexSearcher searcher = newSearcher(reader);
- KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {0, 0}, 10);
- assertMatches(searcher, kvq, 3);
- ScoreDoc[] scoreDocs = searcher.search(kvq, 3).scoreDocs;
- assertIdMatches(reader, "id2", scoreDocs[0]);
- assertIdMatches(reader, "id0", scoreDocs[1]);
- assertIdMatches(reader, "id1", scoreDocs[2]);
- }
- }
-
- public void testSearchBoost() throws IOException {
- try (Directory indexStore =
- getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
- IndexReader reader = DirectoryReader.open(indexStore)) {
- IndexSearcher searcher = newSearcher(reader);
-
- Query vectorQuery = new KnnVectorQuery("field", new float[] {0, 0}, 10);
- ScoreDoc[] scoreDocs = searcher.search(vectorQuery, 3).scoreDocs;
-
- Query boostQuery = new BoostQuery(vectorQuery, 3.0f);
- ScoreDoc[] boostScoreDocs = searcher.search(boostQuery, 3).scoreDocs;
- assertEquals(scoreDocs.length, boostScoreDocs.length);
-
- for (int i = 0; i < scoreDocs.length; i++) {
- ScoreDoc scoreDoc = scoreDocs[i];
- ScoreDoc boostScoreDoc = boostScoreDocs[i];
-
- assertEquals(scoreDoc.doc, boostScoreDoc.doc);
- assertEquals(scoreDoc.score * 3.0f, boostScoreDoc.score, 0.001f);
- }
- }
+public class TestKnnVectorQuery extends BaseKnnVectorQueryTestCase {
+ @Override
+ KnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) {
+ return new KnnVectorQuery(field, query, k, queryFilter);
}
- /** Tests that a KnnVectorQuery applies the filter query */
- public void testSimpleFilter() throws IOException {
- try (Directory indexStore =
- getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
- IndexReader reader = DirectoryReader.open(indexStore)) {
- IndexSearcher searcher = newSearcher(reader);
- Query filter = new TermQuery(new Term("id", "id2"));
- Query kvq = new KnnVectorQuery("field", new float[] {0, 0}, 10, filter);
- TopDocs topDocs = searcher.search(kvq, 3);
- assertEquals(1, topDocs.totalHits.value);
- assertIdMatches(reader, "id2", topDocs.scoreDocs[0]);
- }
+ @Override
+ AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) {
+ return new ThrowingKnnVectorQuery(field, vec, k, query);
}
- public void testFilterWithNoVectorMatches() throws IOException {
- try (Directory indexStore =
- getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
- IndexReader reader = DirectoryReader.open(indexStore)) {
- IndexSearcher searcher = newSearcher(reader);
-
- Query filter = new TermQuery(new Term("other", "value"));
- Query kvq = new KnnVectorQuery("field", new float[] {0, 0}, 10, filter);
- TopDocs topDocs = searcher.search(kvq, 3);
- assertEquals(0, topDocs.totalHits.value);
- }
+ @Override
+ float[] randomVector(int dim) {
+ return TestVectorUtil.randomVector(dim);
}
- /** testDimensionMismatch */
- public void testDimensionMismatch() throws IOException {
- try (Directory indexStore =
- getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
- IndexReader reader = DirectoryReader.open(indexStore)) {
- IndexSearcher searcher = newSearcher(reader);
- KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {0}, 10);
- IllegalArgumentException e =
- expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
- assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage());
- }
+ @Override
+ Field getKnnVectorField(
+ String name, float[] vector, VectorSimilarityFunction similarityFunction) {
+ return new KnnVectorField(name, vector, similarityFunction);
}
- /** testNonVectorField */
- public void testNonVectorField() throws IOException {
- try (Directory indexStore =
- getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
- IndexReader reader = DirectoryReader.open(indexStore)) {
- IndexSearcher searcher = newSearcher(reader);
- assertMatches(searcher, new KnnVectorQuery("xyzzy", new float[] {0}, 10), 0);
- assertMatches(searcher, new KnnVectorQuery("id", new float[] {0}, 10), 0);
- }
- }
-
- /** Test bad parameters */
- public void testIllegalArguments() throws IOException {
- expectThrows(
- IllegalArgumentException.class, () -> new KnnVectorQuery("xx", new float[] {1}, 0));
+ @Override
+ Field getKnnVectorField(String name, float[] vector) {
+ return new KnnVectorField(name, vector);
}
- public void testDifferentReader() throws IOException {
- try (Directory indexStore =
- getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
- IndexReader reader = DirectoryReader.open(indexStore)) {
- KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
- Query dasq = query.rewrite(reader);
- IndexSearcher leafSearcher = newSearcher(reader.leaves().get(0).reader());
- expectThrows(
- IllegalStateException.class,
- () -> dasq.createWeight(leafSearcher, ScoreMode.COMPLETE, 1));
- }
- }
-
- public void testAdvanceShallow() throws IOException {
- try (Directory d = newDirectory()) {
- try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
- for (int j = 0; j < 5; j++) {
- Document doc = new Document();
- doc.add(new KnnVectorField("field", new float[] {j, j}));
- w.addDocument(doc);
- }
- }
- try (IndexReader reader = DirectoryReader.open(d)) {
- IndexSearcher searcher = new IndexSearcher(reader);
- KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
- Query dasq = query.rewrite(reader);
- Scorer scorer =
- dasq.createWeight(searcher, ScoreMode.COMPLETE, 1).scorer(reader.leaves().get(0));
- // before advancing the iterator
- assertEquals(1, scorer.advanceShallow(0));
- assertEquals(1, scorer.advanceShallow(1));
- assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10));
-
- // after advancing the iterator
- scorer.iterator().advance(2);
- assertEquals(2, scorer.advanceShallow(0));
- assertEquals(2, scorer.advanceShallow(2));
- assertEquals(3, scorer.advanceShallow(3));
- assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10));
- }
- }
+ public void testToString() {
+ AbstractKnnVectorQuery q1 = getKnnVectorQuery("f1", new float[] {0, 1}, 10);
+ assertEquals("KnnVectorQuery:f1[0.0,...][10]", q1.toString("ignored"));
}
- public void testScoreEuclidean() throws IOException {
- float[][] vectors = new float[5][];
- for (int j = 0; j < 5; j++) {
- vectors[j] = new float[] {j, j};
- }
- try (Directory d = getStableIndexStore("field", vectors);
- IndexReader reader = DirectoryReader.open(d)) {
- IndexSearcher searcher = new IndexSearcher(reader);
- KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
- Query rewritten = query.rewrite(reader);
- Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
- Scorer scorer = weight.scorer(reader.leaves().get(0));
-
- // prior to advancing, score is 0
- assertEquals(-1, scorer.docID());
- expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
-
- // test getMaxScore
- assertEquals(0, scorer.getMaxScore(-1), 0);
- assertEquals(0, scorer.getMaxScore(0), 0);
- // This is 1 / ((l2distance((2,3), (2, 2)) = 1) + 1) = 0.5
- assertEquals(1 / 2f, scorer.getMaxScore(2), 0);
- assertEquals(1 / 2f, scorer.getMaxScore(Integer.MAX_VALUE), 0);
-
- DocIdSetIterator it = scorer.iterator();
- assertEquals(3, it.cost());
- assertEquals(1, it.nextDoc());
- assertEquals(1 / 6f, scorer.score(), 0);
- assertEquals(3, it.advance(3));
- assertEquals(1 / 2f, scorer.score(), 0);
- assertEquals(NO_MORE_DOCS, it.advance(4));
- expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
- }
+ @Override
+ VectorEncoding vectorEncoding() {
+ return VectorEncoding.FLOAT32;
}
- public void testScoreDotProduct() throws IOException {
+ public void testScoreNegativeDotProduct() throws IOException {
try (Directory d = newDirectory()) {
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
- for (int j = 1; j <= 5; j++) {
- Document doc = new Document();
- doc.add(
- new KnnVectorField(
- "field", VectorUtil.l2normalize(new float[] {j, j * j}), DOT_PRODUCT));
- w.addDocument(doc);
- }
+ Document doc = new Document();
+ doc.add(getKnnVectorField("field", new float[] {-1, 0}, DOT_PRODUCT));
+ w.addDocument(doc);
+ doc = new Document();
+ doc.add(getKnnVectorField("field", new float[] {1, 0}, DOT_PRODUCT));
+ w.addDocument(doc);
}
try (IndexReader reader = DirectoryReader.open(d)) {
assertEquals(1, reader.leaves().size());
IndexSearcher searcher = new IndexSearcher(reader);
- KnnVectorQuery query =
- new KnnVectorQuery("field", VectorUtil.l2normalize(new float[] {2, 3}), 3);
+ AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {1, 0}, 2);
Query rewritten = query.rewrite(reader);
Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
Scorer scorer = weight.scorer(reader.leaves().get(0));
- // prior to advancing, score is undefined
- assertEquals(-1, scorer.docID());
- expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
-
- // test getMaxScore
- assertEquals(0, scorer.getMaxScore(-1), 0);
- /* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
- * normalized by (1 + x) /2.
- */
- float maxAtZero =
- (float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2);
- assertEquals(maxAtZero, scorer.getMaxScore(0), 0.001);
-
- /* max at 2 is actually the score for doc 1 which is the highest (since doc 1 vector (2, 4)
- * is the closest to (2, 3)). This is ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
- * normalized by (1 + x) /2
- */
- float expected =
- (float) ((1 + (2 * 2 + 3 * 4) / Math.sqrt((2 * 2 + 3 * 3) * (2 * 2 + 4 * 4))) / 2);
- assertEquals(expected, scorer.getMaxScore(2), 0);
- assertEquals(expected, scorer.getMaxScore(Integer.MAX_VALUE), 0);
-
+ // scores are normalized to lie in [0, 1]
DocIdSetIterator it = scorer.iterator();
- assertEquals(3, it.cost());
+ assertEquals(2, it.cost());
assertEquals(0, it.nextDoc());
- // doc 0 has (1, 1)
- assertEquals(maxAtZero, scorer.score(), 0.0001);
+ assertEquals(0, scorer.score(), 0);
assertEquals(1, it.advance(1));
- assertEquals(expected, scorer.score(), 0);
- assertEquals(2, it.nextDoc());
- // since topK was 3
- assertEquals(NO_MORE_DOCS, it.advance(4));
- expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
+ assertEquals(1, scorer.score(), 0);
}
}
}
- public void testScoreCosine() throws IOException {
+ public void testScoreDotProduct() throws IOException {
try (Directory d = newDirectory()) {
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
for (int j = 1; j <= 5; j++) {
Document doc = new Document();
- doc.add(new KnnVectorField("field", new float[] {j, j * j}, COSINE));
+ doc.add(
+ getKnnVectorField(
+ "field", VectorUtil.l2normalize(new float[] {j, j * j}), DOT_PRODUCT));
w.addDocument(doc);
}
}
try (IndexReader reader = DirectoryReader.open(d)) {
assertEquals(1, reader.leaves().size());
IndexSearcher searcher = new IndexSearcher(reader);
- KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
+ AbstractKnnVectorQuery query =
+ getKnnVectorQuery("field", VectorUtil.l2normalize(new float[] {2, 3}), 3);
Query rewritten = query.rewrite(reader);
Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
Scorer scorer = weight.scorer(reader.leaves().get(0));
@@ -387,497 +157,6 @@ public class TestKnnVectorQuery extends LuceneTestCase {
}
}
- public void testScoreNegativeDotProduct() throws IOException {
- try (Directory d = newDirectory()) {
- try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
- Document doc = new Document();
- doc.add(new KnnVectorField("field", new float[] {-1, 0}, DOT_PRODUCT));
- w.addDocument(doc);
- doc = new Document();
- doc.add(new KnnVectorField("field", new float[] {1, 0}, DOT_PRODUCT));
- w.addDocument(doc);
- }
- try (IndexReader reader = DirectoryReader.open(d)) {
- assertEquals(1, reader.leaves().size());
- IndexSearcher searcher = new IndexSearcher(reader);
- KnnVectorQuery query = new KnnVectorQuery("field", new float[] {1, 0}, 2);
- Query rewritten = query.rewrite(reader);
- Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
- Scorer scorer = weight.scorer(reader.leaves().get(0));
-
- // scores are normalized to lie in [0, 1]
- DocIdSetIterator it = scorer.iterator();
- assertEquals(2, it.cost());
- assertEquals(0, it.nextDoc());
- assertEquals(0, scorer.score(), 0);
- assertEquals(1, it.advance(1));
- assertEquals(1, scorer.score(), 0);
- }
- }
- }
-
- public void testExplain() throws IOException {
- try (Directory d = newDirectory()) {
- try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
- for (int j = 0; j < 5; j++) {
- Document doc = new Document();
- doc.add(new KnnVectorField("field", new float[] {j, j}));
- w.addDocument(doc);
- }
- }
- try (IndexReader reader = DirectoryReader.open(d)) {
- IndexSearcher searcher = new IndexSearcher(reader);
- KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
- Explanation matched = searcher.explain(query, 2);
- assertTrue(matched.isMatch());
- assertEquals(1 / 2f, matched.getValue());
- assertEquals(0, matched.getDetails().length);
- assertEquals("within top 3", matched.getDescription());
-
- Explanation nomatch = searcher.explain(query, 4);
- assertFalse(nomatch.isMatch());
- assertEquals(0f, nomatch.getValue());
- assertEquals(0, matched.getDetails().length);
- assertEquals("not in top 3", nomatch.getDescription());
- }
- }
- }
-
- public void testExplainMultipleSegments() throws IOException {
- try (Directory d = newDirectory()) {
- try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
- for (int j = 0; j < 5; j++) {
- Document doc = new Document();
- doc.add(new KnnVectorField("field", new float[] {j, j}));
- w.addDocument(doc);
- w.commit();
- }
- }
- try (IndexReader reader = DirectoryReader.open(d)) {
- IndexSearcher searcher = new IndexSearcher(reader);
- KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
- Explanation matched = searcher.explain(query, 2);
- assertTrue(matched.isMatch());
- assertEquals(1 / 2f, matched.getValue());
- assertEquals(0, matched.getDetails().length);
- assertEquals("within top 3", matched.getDescription());
-
- Explanation nomatch = searcher.explain(query, 4);
- assertFalse(nomatch.isMatch());
- assertEquals(0f, nomatch.getValue());
- assertEquals(0, matched.getDetails().length);
- assertEquals("not in top 3", nomatch.getDescription());
- }
- }
- }
-
- /** Test that when vectors are abnormally distributed among segments, we still find the top K */
- public void testSkewedIndex() throws IOException {
- /* We have to choose the numbers carefully here so that some segment has more than the expected
- * number of top K documents, but no more than K documents in total (otherwise we might occasionally
- * randomly fail to find one).
- */
- try (Directory d = newDirectory()) {
- try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
- int r = 0;
- for (int i = 0; i < 5; i++) {
- for (int j = 0; j < 5; j++) {
- Document doc = new Document();
- doc.add(new KnnVectorField("field", new float[] {r, r}));
- doc.add(new StringField("id", "id" + r, Field.Store.YES));
- w.addDocument(doc);
- ++r;
- }
- w.flush();
- }
- }
- try (IndexReader reader = DirectoryReader.open(d)) {
- IndexSearcher searcher = newSearcher(reader);
- TopDocs results = searcher.search(new KnnVectorQuery("field", new float[] {0, 0}, 8), 10);
- assertEquals(8, results.scoreDocs.length);
- assertIdMatches(reader, "id0", results.scoreDocs[0]);
- assertIdMatches(reader, "id7", results.scoreDocs[7]);
-
- // test some results in the middle of the sequence - also tests docid tiebreaking
- results = searcher.search(new KnnVectorQuery("field", new float[] {10, 10}, 8), 10);
- assertEquals(8, results.scoreDocs.length);
- assertIdMatches(reader, "id10", results.scoreDocs[0]);
- assertIdMatches(reader, "id6", results.scoreDocs[7]);
- }
- }
- }
-
- /** Tests with random vectors, number of documents, etc. Uses RandomIndexWriter. */
- public void testRandom() throws IOException {
- int numDocs = atLeast(100);
- int dimension = atLeast(5);
- int numIters = atLeast(10);
- boolean everyDocHasAVector = random().nextBoolean();
- try (Directory d = newDirectory()) {
- RandomIndexWriter w = new RandomIndexWriter(random(), d);
- for (int i = 0; i < numDocs; i++) {
- Document doc = new Document();
- if (everyDocHasAVector || random().nextInt(10) != 2) {
- doc.add(new KnnVectorField("field", randomVector(dimension)));
- }
- w.addDocument(doc);
- }
- w.close();
- try (IndexReader reader = DirectoryReader.open(d)) {
- IndexSearcher searcher = newSearcher(reader);
- for (int i = 0; i < numIters; i++) {
- int k = random().nextInt(80) + 1;
- KnnVectorQuery query = new KnnVectorQuery("field", randomVector(dimension), k);
- int n = random().nextInt(100) + 1;
- TopDocs results = searcher.search(query, n);
- int expected = Math.min(Math.min(n, k), reader.numDocs());
- // we may get fewer results than requested if there are deletions, but this test doesn't
- // test that
- assert reader.hasDeletions() == false;
- assertEquals(expected, results.scoreDocs.length);
- assertTrue(results.totalHits.value >= results.scoreDocs.length);
- // verify the results are in descending score order
- float last = Float.MAX_VALUE;
- for (ScoreDoc scoreDoc : results.scoreDocs) {
- assertTrue(scoreDoc.score <= last);
- last = scoreDoc.score;
- }
- }
- }
- }
- }
-
- /** Tests with random vectors and a random filter. Uses RandomIndexWriter. */
- public void testRandomWithFilter() throws IOException {
- int numDocs = 1000;
- int dimension = atLeast(5);
- int numIters = atLeast(10);
- try (Directory d = newDirectory()) {
- // Always use the default kNN format to have predictable behavior around when it hits
- // visitedLimit. This is fine since the test targets KnnVectorQuery logic, not the kNN format
- // implementation.
- IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec());
- RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc);
- for (int i = 0; i < numDocs; i++) {
- Document doc = new Document();
- doc.add(new KnnVectorField("field", randomVector(dimension)));
- doc.add(new NumericDocValuesField("tag", i));
- doc.add(new IntPoint("tag", i));
- w.addDocument(doc);
- }
- w.forceMerge(1);
- w.close();
-
- try (DirectoryReader reader = DirectoryReader.open(d)) {
- IndexSearcher searcher = newSearcher(reader);
- for (int i = 0; i < numIters; i++) {
- int lower = random().nextInt(500);
-
- // Test a filter with cost less than k and check we use exact search
- Query filter1 = IntPoint.newRangeQuery("tag", lower, lower + 8);
- TopDocs results =
- searcher.search(
- new KnnVectorQuery("field", randomVector(dimension), 10, filter1), numDocs);
- assertEquals(9, results.totalHits.value);
- assertEquals(results.totalHits.value, results.scoreDocs.length);
- expectThrows(
- UnsupportedOperationException.class,
- () ->
- searcher.search(
- new ThrowingKnnVectorQuery("field", randomVector(dimension), 10, filter1),
- numDocs));
-
- // Test a restrictive filter and check we use exact search
- Query filter2 = IntPoint.newRangeQuery("tag", lower, lower + 6);
- results =
- searcher.search(
- new KnnVectorQuery("field", randomVector(dimension), 5, filter2), numDocs);
- assertEquals(5, results.totalHits.value);
- assertEquals(results.totalHits.value, results.scoreDocs.length);
- expectThrows(
- UnsupportedOperationException.class,
- () ->
- searcher.search(
- new ThrowingKnnVectorQuery("field", randomVector(dimension), 5, filter2),
- numDocs));
-
- // Test an unrestrictive filter and check we use approximate search
- Query filter3 = IntPoint.newRangeQuery("tag", lower, numDocs);
- results =
- searcher.search(
- new ThrowingKnnVectorQuery("field", randomVector(dimension), 5, filter3),
- numDocs,
- new Sort(new SortField("tag", SortField.Type.INT)));
- assertEquals(5, results.totalHits.value);
- assertEquals(results.totalHits.value, results.scoreDocs.length);
-
- for (ScoreDoc scoreDoc : results.scoreDocs) {
- FieldDoc fieldDoc = (FieldDoc) scoreDoc;
- assertEquals(1, fieldDoc.fields.length);
-
- int tag = (int) fieldDoc.fields[0];
- assertTrue(lower <= tag && tag <= numDocs);
- }
-
- // Test a filter that exhausts visitedLimit in upper levels, and switches to exact search
- Query filter4 = IntPoint.newRangeQuery("tag", lower, lower + 2);
- expectThrows(
- UnsupportedOperationException.class,
- () ->
- searcher.search(
- new ThrowingKnnVectorQuery("field", randomVector(dimension), 1, filter4),
- numDocs));
- }
- }
- }
- }
-
- /** Tests filtering when all vectors have the same score. */
- @AwaitsFix(bugUrl = "https://github.com/apache/lucene/issues/11787")
- public void testFilterWithSameScore() throws IOException {
- int numDocs = 100;
- int dimension = atLeast(5);
- try (Directory d = newDirectory()) {
- // Always use the default kNN format to have predictable behavior around when it hits
- // visitedLimit. This is fine since the test targets KnnVectorQuery logic, not the kNN format
- // implementation.
- IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec());
- IndexWriter w = new IndexWriter(d, iwc);
- float[] vector = randomVector(dimension);
- for (int i = 0; i < numDocs; i++) {
- Document doc = new Document();
- doc.add(new KnnVectorField("field", vector));
- doc.add(new IntPoint("tag", i));
- w.addDocument(doc);
- }
- w.forceMerge(1);
- w.close();
-
- try (DirectoryReader reader = DirectoryReader.open(d)) {
- IndexSearcher searcher = newSearcher(reader);
- int lower = random().nextInt(50);
- int size = 5;
-
- // Test a restrictive filter, which usually performs exact search
- Query filter1 = IntPoint.newRangeQuery("tag", lower, lower + 6);
- TopDocs results =
- searcher.search(
- new KnnVectorQuery("field", randomVector(dimension), size, filter1), size);
- assertEquals(size, results.scoreDocs.length);
-
- // Test an unrestrictive filter, which usually performs approximate search
- Query filter2 = IntPoint.newRangeQuery("tag", lower, numDocs);
- results =
- searcher.search(
- new KnnVectorQuery("field", randomVector(dimension), size, filter2), size);
- assertEquals(size, results.scoreDocs.length);
- }
- }
- }
-
- public void testDeletes() throws IOException {
- try (Directory dir = newDirectory();
- IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
- final int numDocs = atLeast(100);
- final int dim = 30;
- for (int i = 0; i < numDocs; ++i) {
- Document d = new Document();
- d.add(new StringField("index", String.valueOf(i), Field.Store.YES));
- if (frequently()) {
- d.add(new KnnVectorField("vector", randomVector(dim)));
- }
- w.addDocument(d);
- }
- w.commit();
-
- // Delete some documents at random, both those with and without vectors
- Set<Term> toDelete = new HashSet<>();
- for (int i = 0; i < 25; i++) {
- int index = random().nextInt(numDocs);
- toDelete.add(new Term("index", String.valueOf(index)));
- }
- w.deleteDocuments(toDelete.toArray(new Term[0]));
- w.commit();
-
- int hits = 50;
- try (IndexReader reader = DirectoryReader.open(dir)) {
- Set<String> allIds = new HashSet<>();
- IndexSearcher searcher = new IndexSearcher(reader);
- KnnVectorQuery query = new KnnVectorQuery("vector", randomVector(dim), hits);
- TopDocs topDocs = searcher.search(query, numDocs);
- StoredFields storedFields = reader.storedFields();
- for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
- Document doc = storedFields.document(scoreDoc.doc, Set.of("index"));
- String index = doc.get("index");
- assertFalse(
- "search returned a deleted document: " + index,
- toDelete.contains(new Term("index", index)));
- allIds.add(index);
- }
- assertEquals("search missed some documents", hits, allIds.size());
- }
- }
- }
-
- public void testAllDeletes() throws IOException {
- try (Directory dir = newDirectory();
- IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
- final int numDocs = atLeast(100);
- final int dim = 30;
- for (int i = 0; i < numDocs; ++i) {
- Document d = new Document();
- d.add(new KnnVectorField("vector", randomVector(dim)));
- w.addDocument(d);
- }
- w.commit();
-
- w.deleteDocuments(new MatchAllDocsQuery());
- w.commit();
-
- try (IndexReader reader = DirectoryReader.open(dir)) {
- IndexSearcher searcher = new IndexSearcher(reader);
- KnnVectorQuery query = new KnnVectorQuery("vector", randomVector(dim), numDocs);
- TopDocs topDocs = searcher.search(query, numDocs);
- assertEquals(0, topDocs.scoreDocs.length);
- }
- }
- }
-
- /**
- * Check that the query behaves reasonably when using a custom filter reader where there are no
- * live docs.
- */
- public void testNoLiveDocsReader() throws IOException {
- IndexWriterConfig iwc = newIndexWriterConfig();
- try (Directory dir = newDirectory();
- IndexWriter w = new IndexWriter(dir, iwc)) {
- final int numDocs = 10;
- final int dim = 30;
- for (int i = 0; i < numDocs; ++i) {
- Document d = new Document();
- d.add(new StringField("index", String.valueOf(i), Field.Store.NO));
- d.add(new KnnVectorField("vector", randomVector(dim)));
- w.addDocument(d);
- }
- w.commit();
-
- try (DirectoryReader reader = DirectoryReader.open(dir)) {
- DirectoryReader wrappedReader = new NoLiveDocsDirectoryReader(reader);
- IndexSearcher searcher = new IndexSearcher(wrappedReader);
- KnnVectorQuery query = new KnnVectorQuery("vector", randomVector(dim), numDocs);
- TopDocs topDocs = searcher.search(query, numDocs);
- assertEquals(0, topDocs.scoreDocs.length);
- }
- }
- }
-
- /**
- * Test that KnnVectorQuery optimizes the case where the filter query is backed by {@link
- * BitSetIterator}.
- */
- public void testBitSetQuery() throws IOException {
- IndexWriterConfig iwc = newIndexWriterConfig();
- try (Directory dir = newDirectory();
- IndexWriter w = new IndexWriter(dir, iwc)) {
- final int numDocs = 100;
- final int dim = 30;
- for (int i = 0; i < numDocs; ++i) {
- Document d = new Document();
- d.add(new KnnVectorField("vector", randomVector(dim)));
- w.addDocument(d);
- }
- w.commit();
-
- try (DirectoryReader reader = DirectoryReader.open(dir)) {
- IndexSearcher searcher = new IndexSearcher(reader);
-
- Query filter = new ThrowingBitSetQuery(new FixedBitSet(numDocs));
- expectThrows(
- UnsupportedOperationException.class,
- () ->
- searcher.search(
- new KnnVectorQuery("vector", randomVector(dim), 10, filter), numDocs));
- }
- }
- }
-
- /** Creates a new directory and adds documents with the given vectors as kNN vector fields */
- private Directory getIndexStore(String field, float[]... contents) throws IOException {
- Directory indexStore = newDirectory();
- RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
- VectorEncoding encoding = randomVectorEncoding();
- for (int i = 0; i < contents.length; ++i) {
- Document doc = new Document();
- if (encoding == VectorEncoding.BYTE) {
- BytesRef v = new BytesRef(new byte[contents[i].length]);
- for (int j = 0; j < v.length; j++) {
- v.bytes[j] = (byte) contents[i][j];
- }
- doc.add(new KnnVectorField(field, v, EUCLIDEAN));
- } else {
- doc.add(new KnnVectorField(field, contents[i]));
- }
- doc.add(new StringField("id", "id" + i, Field.Store.YES));
- writer.addDocument(doc);
- }
- // Add some documents without a vector
- for (int i = 0; i < 5; i++) {
- Document doc = new Document();
- doc.add(new StringField("other", "value", Field.Store.NO));
- writer.addDocument(doc);
- }
- writer.close();
- return indexStore;
- }
-
- /**
- * Creates a new directory and adds documents with the given vectors as kNN vector fields,
- * preserving the order of the added documents.
- */
- private Directory getStableIndexStore(String field, float[]... contents) throws IOException {
- Directory indexStore = newDirectory();
- try (IndexWriter writer = new IndexWriter(indexStore, new IndexWriterConfig())) {
- VectorEncoding encoding = randomVectorEncoding();
- for (int i = 0; i < contents.length; ++i) {
- Document doc = new Document();
- if (encoding == VectorEncoding.BYTE) {
- BytesRef v = new BytesRef(new byte[contents[i].length]);
- for (int j = 0; j < v.length; j++) {
- v.bytes[j] = (byte) contents[i][j];
- }
- doc.add(new KnnVectorField(field, v, EUCLIDEAN));
- } else {
- doc.add(new KnnVectorField(field, contents[i]));
- }
- doc.add(new StringField("id", "id" + i, Field.Store.YES));
- writer.addDocument(doc);
- }
- // Add some documents without a vector
- for (int i = 0; i < 5; i++) {
- Document doc = new Document();
- doc.add(new StringField("other", "value", Field.Store.NO));
- writer.addDocument(doc);
- }
- }
- return indexStore;
- }
-
- private void assertMatches(IndexSearcher searcher, Query q, int expectedMatches)
- throws IOException {
- ScoreDoc[] result = searcher.search(q, 1000).scoreDocs;
- assertEquals(expectedMatches, result.length);
- }
-
- private void assertIdMatches(IndexReader reader, String expectedId, ScoreDoc scoreDoc)
- throws IOException {
- String actualId = reader.storedFields().document(scoreDoc.doc).get("id");
- assertEquals(expectedId, actualId);
- }
-
- /**
- * A version of {@link KnnVectorQuery} that throws an error when an exact search is run. This
- * allows us to check what search strategy is being used.
- */
private static class ThrowingKnnVectorQuery extends KnnVectorQuery {
public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter) {
@@ -888,109 +167,10 @@ public class TestKnnVectorQuery extends LuceneTestCase {
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) {
throw new UnsupportedOperationException("exact search is not supported");
}
- }
-
- private static class NoLiveDocsDirectoryReader extends FilterDirectoryReader {
-
- private NoLiveDocsDirectoryReader(DirectoryReader in) throws IOException {
- super(
- in,
- new SubReaderWrapper() {
- @Override
- public LeafReader wrap(LeafReader reader) {
- return new NoLiveDocsLeafReader(reader);
- }
- });
- }
-
- @Override
- protected DirectoryReader doWrapDirectoryReader(DirectoryReader in) throws IOException {
- return new NoLiveDocsDirectoryReader(in);
- }
-
- @Override
- public CacheHelper getReaderCacheHelper() {
- return in.getReaderCacheHelper();
- }
- }
-
- private static class NoLiveDocsLeafReader extends FilterLeafReader {
- private NoLiveDocsLeafReader(LeafReader in) {
- super(in);
- }
-
- @Override
- public int numDocs() {
- return 0;
- }
-
- @Override
- public Bits getLiveDocs() {
- return new Bits.MatchNoBits(in.maxDoc());
- }
-
- @Override
- public CacheHelper getReaderCacheHelper() {
- return in.getReaderCacheHelper();
- }
-
- @Override
- public CacheHelper getCoreCacheHelper() {
- return in.getCoreCacheHelper();
- }
- }
-
- private static class ThrowingBitSetQuery extends Query {
-
- private final FixedBitSet docs;
-
- ThrowingBitSetQuery(FixedBitSet docs) {
- this.docs = docs;
- }
-
- @Override
- public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
- throws IOException {
- return new ConstantScoreWeight(this, boost) {
- @Override
- public Scorer scorer(LeafReaderContext context) throws IOException {
- BitSetIterator bitSetIterator =
- new BitSetIterator(docs, docs.approximateCardinality()) {
- @Override
- public BitSet getBitSet() {
- throw new UnsupportedOperationException("reusing BitSet is not supported");
- }
- };
- return new ConstantScoreScorer(this, score(), scoreMode, bitSetIterator);
- }
-
- @Override
- public boolean isCacheable(LeafReaderContext ctx) {
- return false;
- }
- };
- }
-
- @Override
- public void visit(QueryVisitor visitor) {}
@Override
public String toString(String field) {
- return "throwingBitSetQuery";
+ return null;
}
-
- @Override
- public boolean equals(Object other) {
- return sameClassAs(other) && docs.equals(((ThrowingBitSetQuery) other).docs);
- }
-
- @Override
- public int hashCode() {
- return 31 * classHash() + docs.hashCode();
- }
- }
-
- private VectorEncoding randomVectorEncoding() {
- return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
}
}
diff --git a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java
index a852cdcbd45..76e1b9596ed 100644
--- a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java
+++ b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java
@@ -172,6 +172,17 @@ public class TestVectorUtil extends LuceneTestCase {
return v;
}
+ public static BytesRef randomVectorBytes(int dim) {
+ BytesRef v = TestUtil.randomBinaryTerm(random(), dim);
+ // clip at -127 to avoid overflow
+ for (int i = v.offset; i < v.offset + v.length; i++) {
+ if (v.bytes[i] == -128) {
+ v.bytes[i] = -127;
+ }
+ }
+ return v;
+ }
+
public void testBasicDotProductBytes() {
BytesRef a = new BytesRef(new byte[] {1, 2, 3});
BytesRef b = new BytesRef(new byte[] {-10, 0, 5});
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
index 8c427816014..e3cfa2d462b 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
@@ -281,16 +281,35 @@ public class TestHnswGraph extends LuceneTestCase {
vectors, vectorEncoding, similarityFunction, 10, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
// run some searches
- NeighborQueue nn =
- HnswGraphSearcher.search(
- getTargetVector(),
- 10,
- vectors.copy(),
- vectorEncoding,
- similarityFunction,
- hnsw,
- null,
- Integer.MAX_VALUE);
+ final NeighborQueue nn;
+ switch (vectorEncoding) {
+ case FLOAT32:
+ nn =
+ HnswGraphSearcher.search(
+ getTargetVector(),
+ 10,
+ vectors.copy(),
+ vectorEncoding,
+ similarityFunction,
+ hnsw,
+ null,
+ Integer.MAX_VALUE);
+ break;
+ case BYTE:
+ nn =
+ HnswGraphSearcher.search(
+ getTargetByteVector(),
+ 10,
+ vectors.copy(),
+ vectorEncoding,
+ similarityFunction,
+ hnsw,
+ null,
+ Integer.MAX_VALUE);
+ break;
+ default:
+ throw new IllegalArgumentException("unexpected vector encoding: " + vectorEncoding);
+ }
int[] nodes = nn.nodes();
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
@@ -323,16 +342,35 @@ public class TestHnswGraph extends LuceneTestCase {
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
// the first 10 docs must not be deleted to ensure the expected recall
Bits acceptOrds = createRandomAcceptOrds(10, vectors.size);
- NeighborQueue nn =
- HnswGraphSearcher.search(
- getTargetVector(),
- 10,
- vectors.copy(),
- vectorEncoding,
- similarityFunction,
- hnsw,
- acceptOrds,
- Integer.MAX_VALUE);
+ final NeighborQueue nn;
+ switch (vectorEncoding) {
+ case FLOAT32:
+ nn =
+ HnswGraphSearcher.search(
+ getTargetVector(),
+ 10,
+ vectors.copy(),
+ vectorEncoding,
+ similarityFunction,
+ hnsw,
+ acceptOrds,
+ Integer.MAX_VALUE);
+ break;
+ case BYTE:
+ nn =
+ HnswGraphSearcher.search(
+ getTargetByteVector(),
+ 10,
+ vectors.copy(),
+ vectorEncoding,
+ similarityFunction,
+ hnsw,
+ acceptOrds,
+ Integer.MAX_VALUE);
+ break;
+ default:
+ throw new IllegalArgumentException("unexpected vector encoding: " + vectorEncoding);
+ }
int[] nodes = nn.nodes();
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
int sum = 0;
@@ -362,16 +400,35 @@ public class TestHnswGraph extends LuceneTestCase {
// Check the search finds all accepted vectors
int numAccepted = acceptOrds.cardinality();
- NeighborQueue nn =
- HnswGraphSearcher.search(
- getTargetVector(),
- numAccepted,
- vectors.copy(),
- vectorEncoding,
- similarityFunction,
- hnsw,
- acceptOrds,
- Integer.MAX_VALUE);
+ final NeighborQueue nn;
+ switch (vectorEncoding) {
+ case FLOAT32:
+ nn =
+ HnswGraphSearcher.search(
+ getTargetVector(),
+ numAccepted,
+ vectors.copy(),
+ vectorEncoding,
+ similarityFunction,
+ hnsw,
+ acceptOrds,
+ Integer.MAX_VALUE);
+ break;
+ case BYTE:
+ nn =
+ HnswGraphSearcher.search(
+ getTargetByteVector(),
+ numAccepted,
+ vectors.copy(),
+ vectorEncoding,
+ similarityFunction,
+ hnsw,
+ acceptOrds,
+ Integer.MAX_VALUE);
+ break;
+ default:
+ throw new IllegalArgumentException("unexpected vector encoding: " + vectorEncoding);
+ }
int[] nodes = nn.nodes();
assertEquals(numAccepted, nodes.length);
for (int node : nodes) {
@@ -383,6 +440,10 @@ public class TestHnswGraph extends LuceneTestCase {
return new float[] {1, 0};
}
+ private BytesRef getTargetByteVector() {
+ return new BytesRef(new byte[] {1, 0});
+ }
+
public void testSearchWithSkewedAcceptOrds() throws IOException {
int nDoc = 1000;
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
@@ -431,16 +492,36 @@ public class TestHnswGraph extends LuceneTestCase {
int topK = 50;
int visitedLimit = topK + random().nextInt(5);
- NeighborQueue nn =
- HnswGraphSearcher.search(
- getTargetVector(),
- topK,
- vectors.copy(),
- vectorEncoding,
- similarityFunction,
- hnsw,
- createRandomAcceptOrds(0, vectors.size),
- visitedLimit);
+ final NeighborQueue nn;
+ switch (vectorEncoding) {
+ case FLOAT32:
+ nn =
+ HnswGraphSearcher.search(
+ getTargetVector(),
+ topK,
+ vectors.copy(),
+ vectorEncoding,
+ similarityFunction,
+ hnsw,
+ createRandomAcceptOrds(0, vectors.size),
+ visitedLimit);
+ break;
+ case BYTE:
+ nn =
+ HnswGraphSearcher.search(
+ getTargetByteVector(),
+ topK,
+ vectors.copy(),
+ vectorEncoding,
+ similarityFunction,
+ hnsw,
+ createRandomAcceptOrds(0, vectors.size),
+ visitedLimit);
+ break;
+ default:
+ throw new IllegalArgumentException("unexpected vector encoding: " + vectorEncoding);
+ }
+
assertTrue(nn.incomplete());
// The visited count shouldn't exceed the limit
assertTrue(nn.visitedCount() <= visitedLimit);
@@ -654,7 +735,6 @@ public class TestHnswGraph extends LuceneTestCase {
int totalMatches = 0;
for (int i = 0; i < 100; i++) {
- NeighborQueue actual;
float[] query;
BytesRef bQuery = null;
if (vectorEncoding == VectorEncoding.BYTE) {
@@ -663,16 +743,35 @@ public class TestHnswGraph extends LuceneTestCase {
} else {
query = randomVector(random(), dim);
}
- actual =
- HnswGraphSearcher.search(
- query,
- 100,
- vectors,
- vectorEncoding,
- similarityFunction,
- hnsw,
- acceptOrds,
- Integer.MAX_VALUE);
+ final NeighborQueue actual;
+ switch (vectorEncoding) {
+ case BYTE:
+ actual =
+ HnswGraphSearcher.search(
+ bQuery,
+ 100,
+ vectors,
+ vectorEncoding,
+ similarityFunction,
+ hnsw,
+ acceptOrds,
+ Integer.MAX_VALUE);
+ break;
+ case FLOAT32:
+ actual =
+ HnswGraphSearcher.search(
+ query,
+ 100,
+ vectors,
+ vectorEncoding,
+ similarityFunction,
+ hnsw,
+ acceptOrds,
+ Integer.MAX_VALUE);
+ break;
+ default:
+ throw new IllegalArgumentException("unexpected vector encoding: " + vectorEncoding);
+ }
while (actual.size() > topK) {
actual.pop();
}
diff --git a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java
index ca044d91318..7e71720c6aa 100644
--- a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java
+++ b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java
@@ -41,6 +41,7 @@ import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.Version;
/**
@@ -170,6 +171,12 @@ public class TermVectorLeafReader extends LeafReader {
return null;
}
+ @Override
+ public TopDocs searchNearestVectors(
+ String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
+ return null;
+ }
+
@Override
public void checkIntegrity() throws IOException {}
diff --git a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java
index 39898fbee87..36a330829b2 100644
--- a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java
+++ b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java
@@ -1402,6 +1402,12 @@ public class MemoryIndex {
return null;
}
+ @Override
+ public TopDocs searchNearestVectors(
+ String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
+ return null;
+ }
+
@Override
public void checkIntegrity() throws IOException {
// no-op
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java
index 4224f66fe6b..dd8f2cdbaf5 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java
@@ -28,10 +28,12 @@ import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
+import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
/** Wraps the default KnnVectorsFormat and provides additional assertions. */
public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
@@ -124,7 +126,22 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
FieldInfo fi = fis.fieldInfo(field);
- assert fi != null && fi.getVectorDimension() > 0;
+ assert fi != null
+ && fi.getVectorDimension() > 0
+ && fi.getVectorEncoding() == VectorEncoding.FLOAT32;
+ TopDocs hits = delegate.search(field, target, k, acceptDocs, visitedLimit);
+ assert hits != null;
+ assert hits.scoreDocs.length <= k;
+ return hits;
+ }
+
+ @Override
+ public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
+ throws IOException {
+ FieldInfo fi = fis.fieldInfo(field);
+ assert fi != null
+ && fi.getVectorDimension() > 0
+ && fi.getVectorEncoding() == VectorEncoding.BYTE;
TopDocs hits = delegate.search(field, target, k, acceptDocs, visitedLimit);
assert hits != null;
assert hits.scoreDocs.length <= k;
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java
index a4822f073ff..e54721fca00 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java
@@ -44,6 +44,7 @@ import org.apache.lucene.index.Terms;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
/**
* This is a hack to make index sorting fast, with a {@link LeafReader} that always returns merge
@@ -240,6 +241,12 @@ class MergeReaderWrapper extends LeafReader {
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
}
+ @Override
+ public TopDocs searchNearestVectors(
+ String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
+ return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
+ }
+
@Override
public int numDocs() {
return in.numDocs();
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java
index 3c697dcfcf3..6e01564cdcd 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java
@@ -55,6 +55,7 @@ import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.Version;
import org.junit.Assert;
@@ -235,6 +236,12 @@ public class QueryUtils {
return null;
}
+ @Override
+ public TopDocs searchNearestVectors(
+ String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) {
+ return null;
+ }
+
@Override
public FieldInfos getFieldInfos() {
return FieldInfos.EMPTY;