You are viewing a plain-text version of this content. (The hyperlink to the canonical version is not preserved in this plain-text copy.)
Posted to commits@lucene.apache.org by ju...@apache.org on 2021/08/16 14:46:26 UTC
[lucene] branch main updated: LUCENE-10040: Handle deletions in nearest vector search (#239)
This is an automated email from the ASF dual-hosted git repository.
julietibs pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/lucene.git
The following commit(s) were added to refs/heads/main by this push:
new 6993fb9 LUCENE-10040: Handle deletions in nearest vector search (#239)
6993fb9 is described below
commit 6993fb9a9985372b0f0984b8bdd7434aaa33ad26
Author: Julie Tibshirani <ju...@gmail.com>
AuthorDate: Mon Aug 16 17:44:17 2021 +0300
LUCENE-10040: Handle deletions in nearest vector search (#239)
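This PR extends KnnVectorsReader#search (and LeafReader#searchNearestVectors) to take an
acceptDocs parameter specifying which documents may be returned, or null if all documents are
allowed. KnnVectorQuery passes each segment's live docs, so the search returns the (approximate)
k nearest undeleted docs.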
To implement this, the HNSW search only adds a node to the result set if it is accepted (for
example, a live doc). Deleted docs are still visited and traversed as candidates while the graph
is explored, so that deleted nodes do not cut off paths to live neighbors.
---
lucene/CHANGES.txt | 4 +-
.../simpletext/SimpleTextKnnVectorsReader.java | 3 +-
.../org/apache/lucene/codecs/KnnVectorsFormat.java | 3 +-
.../org/apache/lucene/codecs/KnnVectorsReader.java | 6 +-
.../codecs/lucene90/Lucene90HnswVectorsReader.java | 21 +++++-
.../codecs/perfield/PerFieldKnnVectorsFormat.java | 5 +-
.../java/org/apache/lucene/index/CodecReader.java | 5 +-
.../apache/lucene/index/DocValuesLeafReader.java | 3 +-
.../org/apache/lucene/index/FilterLeafReader.java | 5 +-
.../java/org/apache/lucene/index/LeafReader.java | 4 +-
.../apache/lucene/index/MergeReaderWrapper.java | 5 +-
.../apache/lucene/index/ParallelLeafReader.java | 5 +-
.../lucene/index/SlowCodecReaderWrapper.java | 5 +-
.../apache/lucene/index/SortingCodecReader.java | 2 +-
.../org/apache/lucene/search/KnnVectorQuery.java | 4 +-
.../org/apache/lucene/util/hnsw/HnswGraph.java | 23 +++++--
.../apache/lucene/util/hnsw/HnswGraphBuilder.java | 3 +-
.../org/apache/lucene/util/hnsw/NeighborQueue.java | 7 --
.../perfield/TestPerFieldKnnVectorsFormat.java | 15 ++---
.../test/org/apache/lucene/index/TestKnnGraph.java | 4 +-
.../lucene/index/TestSegmentToThreadMapping.java | 2 +-
.../apache/lucene/search/TestKnnVectorQuery.java | 74 ++++++++++++++++++++++
.../apache/lucene/util/hnsw/KnnGraphTester.java | 4 +-
.../hnsw/{TestHnsw.java => TestHnswGraph.java} | 58 +++++++++++++++--
.../search/highlight/TermVectorLeafReader.java | 2 +-
.../apache/lucene/index/memory/MemoryIndex.java | 2 +-
.../asserting/AssertingKnnVectorsFormat.java | 5 +-
.../java/org/apache/lucene/search/QueryUtils.java | 2 +-
28 files changed, 222 insertions(+), 59 deletions(-)
diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 6460112..5f970eb 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -7,9 +7,9 @@ http://s.apache.org/luceneversions
New Features
-* LUCENE-9322 LUCENE-9855: Vector-valued fields, Lucene90 Codec (Mike Sokolov, Julie Tibshirani, Tomoko Uchida)
+* LUCENE-9322, LUCENE-9855: Vector-valued fields, Lucene90 Codec (Mike Sokolov, Julie Tibshirani, Tomoko Uchida)
-* LUCENE-9004: Approximate nearest vector search via NSW graphs (Mike Sokolov, Tomoko Uchida et al.)
+* LUCENE-9004, LUCENE-10040: Approximate nearest vector search via NSW graphs (Mike Sokolov, Tomoko Uchida et al.)
* LUCENE-9659: SpanPayloadCheckQuery now supports inequalities. (Kevin Watters, Gus Heck)
diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java
index dcc8518..7fdf266 100644
--- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java
+++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java
@@ -37,6 +37,7 @@ import org.apache.lucene.store.BufferedChecksumIndexInput;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.IOUtils;
@@ -138,7 +139,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
}
@Override
- public TopDocs search(String field, float[] target, int k) throws IOException {
+ public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
throw new UnsupportedOperationException();
}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java
index 3d0f264..4b58f2d 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java
@@ -23,6 +23,7 @@ import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
+import org.apache.lucene.util.Bits;
import org.apache.lucene.util.NamedSPILoader;
/**
@@ -99,7 +100,7 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
}
@Override
- public TopDocs search(String field, float[] target, int k) {
+ public TopDocs search(String field, float[] target, int k, Bits acceptDocs) {
return TopDocsCollector.EMPTY_TOPDOCS;
}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java
index beca006..b692ace 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java
@@ -22,6 +22,7 @@ import java.io.IOException;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Accountable;
+import org.apache.lucene.util.Bits;
/** Reads vectors from an index. */
public abstract class KnnVectorsReader implements Closeable, Accountable {
@@ -51,9 +52,12 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
* @param field the vector field to search
* @param target the vector-valued query
* @param k the number of docs to return
+ * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
+ * if they are all allowed to match.
* @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores.
*/
- public abstract TopDocs search(String field, float[] target, int k) throws IOException;
+ public abstract TopDocs search(String field, float[] target, int k, Bits acceptDocs)
+ throws IOException;
/**
* Returns an instance optimized for merging. This instance may only be consumed in the thread
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
index 6a69ab9..70e386d 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
@@ -43,6 +43,7 @@ import org.apache.lucene.search.TotalHits;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
@@ -232,7 +233,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
}
@Override
- public TopDocs search(String field, float[] target, int k) throws IOException {
+ public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null || fieldEntry.dimension == 0) {
return null;
@@ -250,6 +251,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
vectorValues,
fieldEntry.similarityFunction,
getGraphValues(fieldEntry),
+ getAcceptOrds(acceptDocs, fieldEntry),
random);
int i = 0;
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
@@ -276,6 +278,23 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
return new OffHeapVectorValues(fieldEntry, bytesSlice);
}
+ private Bits getAcceptOrds(Bits acceptDocs, FieldEntry fieldEntry) {
+ if (acceptDocs == null) {
+ return null;
+ }
+ return new Bits() {
+ @Override
+ public boolean get(int index) {
+ return acceptDocs.get(fieldEntry.ordToDoc[index]);
+ }
+
+ @Override
+ public int length() {
+ return fieldEntry.ordToDoc.length;
+ }
+ };
+ }
+
public KnnGraphValues getGraphValues(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field);
if (info == null) {
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java
index 060f032..0e5cb00 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java
@@ -33,6 +33,7 @@ import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
/**
@@ -240,12 +241,12 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
}
@Override
- public TopDocs search(String field, float[] target, int k) throws IOException {
+ public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
KnnVectorsReader knnVectorsReader = fields.get(field);
if (knnVectorsReader == null) {
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
} else {
- return knnVectorsReader.search(field, target, k);
+ return knnVectorsReader.search(field, target, k, acceptDocs);
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/index/CodecReader.java b/lucene/core/src/java/org/apache/lucene/index/CodecReader.java
index b25087a..6942051 100644
--- a/lucene/core/src/java/org/apache/lucene/index/CodecReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/CodecReader.java
@@ -26,6 +26,7 @@ import org.apache.lucene.codecs.PointsReader;
import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.util.Bits;
/** LeafReader implemented by codec APIs. */
public abstract class CodecReader extends LeafReader {
@@ -211,7 +212,7 @@ public abstract class CodecReader extends LeafReader {
}
@Override
- public final TopDocs searchNearestVectors(String field, float[] target, int k)
+ public final TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
throws IOException {
ensureOpen();
FieldInfo fi = getFieldInfos().fieldInfo(field);
@@ -220,7 +221,7 @@ public abstract class CodecReader extends LeafReader {
return null;
}
- return getVectorReader().search(field, target, k);
+ return getVectorReader().search(field, target, k, acceptDocs);
}
@Override
diff --git a/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java
index 4f0ace4..f618c6c 100644
--- a/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java
@@ -53,7 +53,8 @@ abstract class DocValuesLeafReader extends LeafReader {
}
@Override
- public TopDocs searchNearestVectors(String field, float[] target, int k) throws IOException {
+ public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
+ throws IOException {
throw new UnsupportedOperationException();
}
diff --git a/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java
index c08a559..cba9b99 100644
--- a/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java
@@ -345,8 +345,9 @@ public abstract class FilterLeafReader extends LeafReader {
}
@Override
- public TopDocs searchNearestVectors(String field, float[] target, int k) throws IOException {
- return in.searchNearestVectors(field, target, k);
+ public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
+ throws IOException {
+ return in.searchNearestVectors(field, target, k, acceptDocs);
}
@Override
diff --git a/lucene/core/src/java/org/apache/lucene/index/LeafReader.java b/lucene/core/src/java/org/apache/lucene/index/LeafReader.java
index 95f4176..729db64 100644
--- a/lucene/core/src/java/org/apache/lucene/index/LeafReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/LeafReader.java
@@ -222,10 +222,12 @@ public abstract class LeafReader extends IndexReader {
* @param field the vector field to search
* @param target the vector-valued query
* @param k the number of docs to return
+ * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
+ * if they are all allowed to match.
* @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores.
* @lucene.experimental
*/
- public abstract TopDocs searchNearestVectors(String field, float[] target, int k)
+ public abstract TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
throws IOException;
/**
diff --git a/lucene/core/src/java/org/apache/lucene/index/MergeReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/MergeReaderWrapper.java
index 3d925de..ef4d462 100644
--- a/lucene/core/src/java/org/apache/lucene/index/MergeReaderWrapper.java
+++ b/lucene/core/src/java/org/apache/lucene/index/MergeReaderWrapper.java
@@ -209,8 +209,9 @@ class MergeReaderWrapper extends LeafReader {
}
@Override
- public TopDocs searchNearestVectors(String field, float[] target, int k) throws IOException {
- return in.searchNearestVectors(field, target, k);
+ public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs)
+ throws IOException {
+ return in.searchNearestVectors(field, target, k, acceptDocs);
}
@Override
diff --git a/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java
index 6ab727b..c8d1005 100644
--- a/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java
@@ -398,10 +398,11 @@ public class ParallelLeafReader extends LeafReader {
}
@Override
- public TopDocs searchNearestVectors(String fieldName, float[] target, int k) throws IOException {
+ public TopDocs searchNearestVectors(String fieldName, float[] target, int k, Bits acceptDocs)
+ throws IOException {
ensureOpen();
LeafReader reader = fieldToReader.get(fieldName);
- return reader == null ? null : reader.searchNearestVectors(fieldName, target, k);
+ return reader == null ? null : reader.searchNearestVectors(fieldName, target, k, acceptDocs);
}
@Override
diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java
index de62965..3363dc0 100644
--- a/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java
+++ b/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java
@@ -167,8 +167,9 @@ public final class SlowCodecReaderWrapper {
}
@Override
- public TopDocs search(String field, float[] target, int k) throws IOException {
- return reader.searchNearestVectors(field, target, k);
+ public TopDocs search(String field, float[] target, int k, Bits acceptDocs)
+ throws IOException {
+ return reader.searchNearestVectors(field, target, k, acceptDocs);
}
@Override
diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java
index 479df73..f808c90 100644
--- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java
@@ -315,7 +315,7 @@ public final class SortingCodecReader extends FilterCodecReader {
}
@Override
- public TopDocs search(String field, float[] target, int k) {
+ public TopDocs search(String field, float[] target, int k, Bits acceptDocs) {
throw new UnsupportedOperationException();
}
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
index 5dccb80..6050920 100644
--- a/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
@@ -26,6 +26,7 @@ import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.util.Bits;
/** Uses {@link KnnVectorsReader#search} to perform nearest Neighbour search. */
public class KnnVectorQuery extends Query {
@@ -70,7 +71,8 @@ public class KnnVectorQuery extends Query {
}
private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
- TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf);
+ Bits liveDocs = ctx.reader().getLiveDocs();
+ TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf, liveDocs);
if (results == null) {
return NO_RESULTS;
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
index 49f2c95..d1f0420 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
@@ -26,6 +26,7 @@ import java.util.Random;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.util.Bits;
import org.apache.lucene.util.SparseFixedBitSet;
/**
@@ -83,6 +84,8 @@ public final class HnswGraph extends KnnGraphValues {
* @param vectors vector values
* @param graphValues the graph values. May represent the entire graph, or a level in a
* hierarchical graph.
+ * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
+ * {@code null} if they are all allowed to match.
* @param random a source of randomness, used for generating entry points to the graph
* @return a priority queue holding the closest neighbors found
*/
@@ -93,12 +96,15 @@ public final class HnswGraph extends KnnGraphValues {
RandomAccessVectorValues vectors,
VectorSimilarityFunction similarityFunction,
KnnGraphValues graphValues,
+ Bits acceptOrds,
Random random)
throws IOException {
int size = graphValues.size();
// MIN heap, holding the top results
NeighborQueue results = new NeighborQueue(numSeed, similarityFunction.reversed);
+ // MAX heap, from which to pull the candidate nodes
+ NeighborQueue candidates = new NeighborQueue(numSeed, !similarityFunction.reversed);
// set of ordinals that have been visited by search on this layer, used to avoid backtracking
SparseFixedBitSet visited = new SparseFixedBitSet(size);
@@ -109,13 +115,14 @@ public final class HnswGraph extends KnnGraphValues {
if (visited.get(entryPoint) == false) {
visited.set(entryPoint);
// explore the topK starting points of some random numSeed probes
- results.add(entryPoint, similarityFunction.compare(query, vectors.vectorValue(entryPoint)));
+ float score = similarityFunction.compare(query, vectors.vectorValue(entryPoint));
+ candidates.add(entryPoint, score);
+ if (acceptOrds == null || acceptOrds.get(entryPoint)) {
+ results.add(entryPoint, score);
+ }
}
}
- // MAX heap, from which to pull the candidate nodes
- NeighborQueue candidates = results.copy(!similarityFunction.reversed);
-
// Set the bound to the worst current result and below reject any newly-generated candidates
// failing
// to exceed this bound
@@ -138,10 +145,14 @@ public final class HnswGraph extends KnnGraphValues {
continue;
}
visited.set(friendOrd);
+
float score = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
- if (results.insertWithOverflow(friendOrd, score)) {
+ if (results.size() < numSeed || bound.check(score) == false) {
candidates.add(friendOrd, score);
- bound.set(results.topScore());
+ if (acceptOrds == null || acceptOrds.get(friendOrd)) {
+ results.insertWithOverflow(friendOrd, score);
+ bound.set(results.topScore());
+ }
}
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index c7ff31a..d12a731 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -134,9 +134,10 @@ public final class HnswGraphBuilder {
/** Inserts a doc with vector value to the graph */
void addGraphNode(float[] value) throws IOException {
+ // We pass 'null' for acceptOrds because there are no deletions while building the graph
NeighborQueue candidates =
HnswGraph.search(
- value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, random);
+ value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, null, random);
int node = hnsw.addNode();
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java
index bab361b..4102dff 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java
@@ -42,13 +42,6 @@ public class NeighborQueue {
}
}
- NeighborQueue copy(boolean reversed) {
- int size = size();
- NeighborQueue copy = new NeighborQueue(size, reversed);
- copy.heap.pushAll(heap);
- return copy;
- }
-
/** @return the number of elements in the heap */
public int size() {
return heap.size();
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java
index b170b0d..bc0eae5 100644
--- a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java
+++ b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java
@@ -38,6 +38,7 @@ import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.RandomCodec;
import org.apache.lucene.index.SegmentReadState;
@@ -101,19 +102,13 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
// Double-check the vectors were written
try (IndexReader ireader = DirectoryReader.open(directory)) {
+ LeafReader reader = ireader.leaves().get(0).reader();
TopDocs hits1 =
- ireader
- .leaves()
- .get(0)
- .reader()
- .searchNearestVectors("field1", new float[] {1, 2, 3}, 10);
+ reader.searchNearestVectors("field1", new float[] {1, 2, 3}, 10, reader.getLiveDocs());
assertEquals(1, hits1.scoreDocs.length);
+
TopDocs hits2 =
- ireader
- .leaves()
- .get(0)
- .reader()
- .searchNearestVectors("field2", new float[] {1, 2, 3}, 10);
+ reader.searchNearestVectors("field2", new float[] {1, 2, 3}, 10, reader.getLiveDocs());
assertEquals(1, hits2.scoreDocs.length);
}
}
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
index ba43b79..b035a2f 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
@@ -42,6 +42,7 @@ import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.VectorUtil;
@@ -291,7 +292,8 @@ public class TestKnnGraph extends LuceneTestCase {
private static TopDocs doKnnSearch(IndexReader reader, float[] vector, int k) throws IOException {
TopDocs[] results = new TopDocs[reader.leaves().size()];
for (LeafReaderContext ctx : reader.leaves()) {
- results[ctx.ord] = ctx.reader().searchNearestVectors(KNN_GRAPH_FIELD, vector, k);
+ Bits liveDocs = ctx.reader().getLiveDocs();
+ results[ctx.ord] = ctx.reader().searchNearestVectors(KNN_GRAPH_FIELD, vector, k, liveDocs);
if (ctx.docBase > 0) {
for (ScoreDoc doc : results[ctx.ord].scoreDocs) {
doc.doc += ctx.docBase;
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java b/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java
index 74888fd..c76968f 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java
@@ -112,7 +112,7 @@ public class TestSegmentToThreadMapping extends LuceneTestCase {
}
@Override
- public TopDocs searchNearestVectors(String field, float[] target, int k) {
+ public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) {
return null;
}
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
index 862f8f7..6443c86 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
@@ -16,10 +16,13 @@
*/
package org.apache.lucene.search;
+import static com.carrotsearch.randomizedtesting.RandomizedTest.frequently;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.TestVectorUtil.randomVector;
import java.io.IOException;
+import java.util.HashSet;
+import java.util.Set;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnVectorField;
@@ -303,6 +306,77 @@ public class TestKnnVectorQuery extends LuceneTestCase {
}
}
+ public void testDeletes() throws IOException {
+ try (Directory dir = newDirectory();
+ IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
+ final int numDocs = atLeast(100);
+ final int dim = 30;
+ int docIndex = 0;
+ for (int i = 0; i < numDocs; ++i) {
+ Document d = new Document();
+ if (frequently()) {
+ d.add(new StringField("index", String.valueOf(docIndex), Field.Store.YES));
+ d.add(new KnnVectorField("vector", randomVector(dim)));
+ docIndex++;
+ } else {
+ d.add(new StringField("other", "value" + (i % 5), Field.Store.NO));
+ }
+ w.addDocument(d);
+ }
+ w.commit();
+
+ // Delete some documents at random, both those with and without vectors
+ Set<Term> toDelete = new HashSet<>();
+ for (int i = 0; i < 20; i++) {
+ int index = random().nextInt(docIndex);
+ toDelete.add(new Term("index", String.valueOf(index)));
+ }
+ w.deleteDocuments(toDelete.toArray(new Term[0]));
+ w.deleteDocuments(new Term("other", "value" + random().nextInt(5)));
+ w.commit();
+
+ try (IndexReader reader = DirectoryReader.open(dir)) {
+ Set<String> allIds = new HashSet<>();
+ IndexSearcher searcher = new IndexSearcher(reader);
+ KnnVectorQuery query = new KnnVectorQuery("vector", randomVector(dim), numDocs);
+ TopDocs topDocs = searcher.search(query, numDocs);
+ for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
+ Document doc = reader.document(scoreDoc.doc, Set.of("index"));
+ String index = doc.get("index");
+ assertFalse(
+ "search returned a deleted document: " + index,
+ toDelete.contains(new Term("index", index)));
+ allIds.add(index);
+ }
+ assertEquals("search missed some documents", docIndex - toDelete.size(), allIds.size());
+ }
+ }
+ }
+
+ public void testAllDeletes() throws IOException {
+ try (Directory dir = newDirectory();
+ IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
+ final int numDocs = atLeast(100);
+ final int dim = 30;
+ for (int i = 0; i < numDocs; ++i) {
+ Document d = new Document();
+ d.add(new KnnVectorField("vector", randomVector(dim)));
+ w.addDocument(d);
+ }
+ w.commit();
+
+ w.deleteDocuments(new MatchAllDocsQuery());
+ w.commit();
+
+ try (IndexReader reader = DirectoryReader.open(dir)) {
+ IndexSearcher searcher = new IndexSearcher(reader);
+ KnnVectorQuery query = new KnnVectorQuery("vector", randomVector(dim), numDocs);
+ TopDocs topDocs = searcher.search(query, numDocs);
+ assertEquals(0, topDocs.scoreDocs.length);
+ }
+ }
+ }
+
private Directory getIndexStore(String field, float[]... contents) throws IOException {
Directory indexStore = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
index caf94fe..fcdf0aa 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
@@ -58,6 +58,7 @@ import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
+import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IntroSorter;
import org.apache.lucene.util.PrintStreamInfoStream;
@@ -424,7 +425,8 @@ public class KnnGraphTester {
IndexReader reader, String field, float[] vector, int k, int fanout) throws IOException {
TopDocs[] results = new TopDocs[reader.leaves().size()];
for (LeafReaderContext ctx : reader.leaves()) {
- results[ctx.ord] = ctx.reader().searchNearestVectors(field, vector, k + fanout);
+ Bits liveDocs = ctx.reader().getLiveDocs();
+ results[ctx.ord] = ctx.reader().searchNearestVectors(field, vector, k + fanout, liveDocs);
int docBase = ctx.docBase;
for (ScoreDoc scoreDoc : results[ctx.ord].scoreDocs) {
scoreDoc.doc += docBase;
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
similarity index 89%
rename from lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java
rename to lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
index 676cae8..bec0541 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
@@ -45,12 +45,14 @@ import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.VectorUtil;
/** Tests HNSW KNN graphs */
-public class TestHnsw extends LuceneTestCase {
+public class TestHnswGraph extends LuceneTestCase {
// test writing out and reading in a graph gives the expected graph
public void testReadWrite() throws IOException {
@@ -138,6 +140,7 @@ public class TestHnsw extends LuceneTestCase {
vectors.randomAccess(),
VectorSimilarityFunction.DOT_PRODUCT,
hnsw,
+ null,
random());
int sum = 0;
for (int node : nn.nodes()) {
@@ -156,6 +159,35 @@ public class TestHnsw extends LuceneTestCase {
}
}
+  /**
+   * Checks that a filtered HNSW search only returns accepted ords and still finds (close to) the
+   * best accepted ords.
+   */
+  public void testSearchWithAcceptOrds() throws IOException {
+    int nDoc = 100;
+    int topK = 10;
+    CircularVectorValues vectors = new CircularVectorValues(nDoc);
+    HnswGraphBuilder builder =
+        new HnswGraphBuilder(
+            vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt());
+    HnswGraph hnsw = builder.build(vectors);
+
+    Bits acceptOrds = createRandomAcceptOrds(vectors.size);
+    NeighborQueue nn =
+        HnswGraph.search(
+            new float[] {1, 0},
+            topK,
+            5,
+            vectors.randomAccess(),
+            VectorSimilarityFunction.DOT_PRODUCT,
+            hnsw,
+            acceptOrds,
+            random());
+    int sum = 0;
+    for (int node : nn.nodes()) {
+      assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
+      sum += node;
+    }
+    // The lowest docIds are closest to the query, so with ~100% recall the results are the topK
+    // lowest *accepted* ords. That ideal sum depends on which ords were randomly rejected (about 72
+    // on average for a 2/3 acceptance rate), so a fixed bound of sum(0,9) + 30 = 75 would fail
+    // roughly half the time. Compare against the ideal sum for this particular accept set instead.
+    int idealSum = 0;
+    for (int ord = 0, found = 0; ord < vectors.size && found < topK; ord++) {
+      if (acceptOrds.get(ord)) {
+        idealSum += ord;
+        found++;
+      }
+    }
+    // Same slack (30) that the fixed bound allowed over the unfiltered ideal sum(0,9) = 45.
+    assertTrue("sum(result docs)=" + sum + ", ideal=" + idealSum, sum < idealSum + 30);
+  }
+
public void testBoundsCheckerMax() {
BoundsChecker max = BoundsChecker.create(false);
float f = random().nextFloat() - 0.5f;
@@ -279,16 +311,21 @@ public class TestHnsw extends LuceneTestCase {
HnswGraphBuilder builder =
new HnswGraphBuilder(vectors, similarityFunction, 10, 30, random().nextLong());
HnswGraph hnsw = builder.build(vectors);
+ Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(size);
+
int totalMatches = 0;
for (int i = 0; i < 100; i++) {
float[] query = randomVector(random(), dim);
NeighborQueue actual =
- HnswGraph.search(query, topK, 100, vectors, similarityFunction, hnsw, random());
+ HnswGraph.search(
+ query, topK, 100, vectors, similarityFunction, hnsw, acceptOrds, random());
NeighborQueue expected = new NeighborQueue(topK, similarityFunction.reversed);
for (int j = 0; j < size; j++) {
- float[] v = vectors.vectorValue(j);
- if (v != null) {
- expected.insertWithOverflow(j, similarityFunction.compare(query, vectors.vectorValue(j)));
+ if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
+ expected.add(j, similarityFunction.compare(query, vectors.vectorValue(j)));
+ if (expected.size() > topK) {
+ expected.pop();
+ }
}
}
assertEquals(topK, actual.size());
@@ -455,6 +492,17 @@ public class TestHnsw extends LuceneTestCase {
}
}
+  /**
+   * Builds a random accept set over {@code length} ords, independently keeping each ord with
+   * probability 0.667 (roughly two thirds).
+   */
+  private static Bits createRandomAcceptOrds(int length) {
+    FixedBitSet accepted = new FixedBitSet(length);
+    int ord = 0;
+    while (ord < accepted.length()) {
+      boolean keep = random().nextFloat() < 0.667f;
+      if (keep) {
+        accepted.set(ord);
+      }
+      ord++;
+    }
+    return accepted;
+  }
+
private static float[] randomVector(Random random, int dim) {
float[] vec = new float[dim];
for (int i = 0; i < dim; i++) {
diff --git a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java
index a307d67..8a3e992 100644
--- a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java
+++ b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java
@@ -162,7 +162,7 @@ public class TermVectorLeafReader extends LeafReader {
}
@Override
- public TopDocs searchNearestVectors(String field, float[] target, int k) {
+ public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) {
+ // No nearest-vector search is performed here: always returns null, so acceptDocs is unused.
return null;
}
diff --git a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java
index b22d485..5b06e3f 100644
--- a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java
+++ b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java
@@ -1373,7 +1373,7 @@ public class MemoryIndex {
}
@Override
- public TopDocs searchNearestVectors(String field, float[] target, int k) {
+ public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) {
+ // No nearest-vector search is performed here: always returns null, so acceptDocs is unused.
return null;
}
diff --git a/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingKnnVectorsFormat.java b/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingKnnVectorsFormat.java
index 135a248..180c6df 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingKnnVectorsFormat.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingKnnVectorsFormat.java
@@ -26,6 +26,7 @@ import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.util.Bits;
import org.apache.lucene.util.TestUtil;
/** Wraps the default KnnVectorsFormat and provides additional assertions. */
@@ -98,8 +99,8 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
}
@Override
- public TopDocs search(String field, float[] target, int k) throws IOException {
- TopDocs hits = delegate.search(field, target, k);
+ public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
+ TopDocs hits = delegate.search(field, target, k, acceptDocs);
assert hits != null;
assert hits.scoreDocs.length <= k;
return hits;
diff --git a/lucene/test-framework/src/java/org/apache/lucene/search/QueryUtils.java b/lucene/test-framework/src/java/org/apache/lucene/search/QueryUtils.java
index 06fa4bc..48ed7b2 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/search/QueryUtils.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/search/QueryUtils.java
@@ -216,7 +216,7 @@ public class QueryUtils {
}
@Override
- public TopDocs searchNearestVectors(String field, float[] target, int k) {
+ public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) {
+ // No nearest-vector search is performed here: always returns null, so acceptDocs is unused.
return null;
}