You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by ma...@apache.org on 2021/10/05 10:44:14 UTC
[lucene] branch hnsw updated: Disk write and read of hnsw graph
(#315)
This is an automated email from the ASF dual-hosted git repository.
mayya pushed a commit to branch hnsw
in repository https://gitbox.apache.org/repos/asf/lucene.git
The following commit(s) were added to refs/heads/hnsw by this push:
new 5e42fc2 Disk write and read of hnsw graph (#315)
5e42fc2 is described below
commit 5e42fc24a2b88a9b25e3153d433c80011200702b
Author: Mayya Sharipova <ma...@elastic.co>
AuthorDate: Tue Oct 5 06:41:21 2021 -0400
Disk write and read of hnsw graph (#315)
Disk write and read of hierarchical nsw graph.
Modify graph files as follows:
- .vem - meta file
- .vec - vector data file
- .vex - vector graph index, index file for the vector graph file
- .veg - vector graph data file, stores graph connections (previously
called .vex file)
.vem file
+-------------+--------+------------------+------------------+-
| FieldNumber | SimFun | VectorDataOffset | vectorDataLength |
+-------------+--------+------------------+------------------+-
-+------------------+------------------+-----------------+-----------------+-
| graphIndexOffset | graphIndexLength | graphDataOffset | graphDataLength |
-+------------------+------------------+-----------------+-----------------+-
--+-------------+------------+------+--------+
| NumOfLevels | Dimensions | Size | DocIds |
--+-------------+------------+------+--------+
- graph offsets are moved to the .vex file.
This keeps the metadata file smaller, and possibly allows
later loading graph offsets on first use, or making
them off-heap.
.vec file stays the same
.vex file:
+-------------+--------------------+--------------------+---------------+--
| NumOfLevels | NumOfNodesOnLevel0 | NumOfNodesOnLevel1 | NodesOnLevel1 | ...
+-------------+--------------------+--------------------+---------------+--
-+----------------------+-----------------+-
| NumOfNodesOnLevelMax | NodesOnLevelMax |
-+----------------------+-----------------+--
--+------------------------------+----+--------------------------------+
| GraphOffsetsForNodesOnLevel0 | ...| GraphOffsetsForNodesOnLevelMax |
--+------------------------------+----+--------------------------------+
.veg file:
+------------------------+------------------------+-----+--------------------------+
| Level0NodesConnections | Level1NodesConnections | ....| LevelMaxNodesConnections |
+------------------------+------------------------+-----+--------------------------+
---
.../codecs/lucene90/Lucene90HnswVectorsFormat.java | 56 +++-
.../codecs/lucene90/Lucene90HnswVectorsReader.java | 190 ++++++++++--
.../codecs/lucene90/Lucene90HnswVectorsWriter.java | 160 ++++++----
.../org/apache/lucene/index/KnnGraphValues.java | 21 +-
.../org/apache/lucene/util/hnsw/HnswGraph.java | 78 ++---
.../apache/lucene/util/hnsw/HnswGraphBuilder.java | 61 +---
.../test/org/apache/lucene/index/TestKnnGraph.java | 333 ++++++++++++++-------
.../apache/lucene/util/hnsw/KnnGraphTester.java | 2 +-
.../apache/lucene/util/hnsw/TestHNSWGraph2.java | 161 ----------
.../org/apache/lucene/util/hnsw/TestHnswGraph.java | 253 ++++++----------
10 files changed, 679 insertions(+), 636 deletions(-)
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsFormat.java
index 4d033ea..f1e6aab 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsFormat.java
@@ -28,21 +28,50 @@ import org.apache.lucene.util.hnsw.HnswGraph;
/**
* Lucene 9.0 vector format, which encodes numeric vector values and an optional associated graph
* connecting the documents having values. The graph is used to power HNSW search. The format
- * consists of three files:
+ * consists of four files:
*
* <h2>.vec (vector data) file</h2>
*
* <p>This file stores all the floating-point vector data ordered by field, document ordinal, and
* vector dimension. The floats are stored in little-endian byte order.
*
- * <h2>.vex (vector index) file</h2>
+ * <h2>.vex (graph index) file</h2>
*
- * <p>Stores graphs connecting the documents for each field. For each document having a vector for a
- * given field, this is stored as:
+ * <p>Stores the graph info for each vector field: the graph nodes on each level and, for each node,
+ * a pointer into the graph data file to the entry holding this node's neighbours. For each vector
+ * field it is organized as:
*
* <ul>
- * <li><b>[int32]</b> the number of neighbor nodes
- * <li><b>array[vint]</b> the neighbor ordinals, delta-encoded (initially subtracting -1)
+ * <li><b>[int]</b> the number of levels in the graph
+ * <li>For each level
+ * <ul>
+ * <li><b>[int]</b> the number of nodes on this level
+ * <li><b>array[vint]</b> for levels greater than 0, the list of nodes on this level, stored
+ * as level 0 node ordinals.
+ * </ul>
+ * <li>For each level
+ * <ul>
+ * <li><b>array[vlong]</b> for each node, the offset (delta-encoded relative to the previous
+ * node's offset) of its entry in the graph data (.veg) file that stores this node's
+ * connections.
+ * </ul>
+ * </ul>
+ *
+ * <h2>.veg (graph data) file</h2>
+ *
+ * <p>Stores the graphs connecting the documents for each field, organized as a list of nodes'
+ * neighbours as follows:
+ *
+ * <ul>
+ * <li>For each level:
+ * <ul>
+ * <li>For each node:
+ * <ul>
+ * <li><b>[int32]</b> the number of neighbor nodes
+ * <li><b>array[vint]</b> the neighbor ordinals, delta-encoded (initially subtracting
+ * -1)
+ * </ul>
+ * </ul>
* </ul>
*
* <h2>.vem (vector metadata) file</h2>
@@ -54,13 +83,14 @@ import org.apache.lucene.util.hnsw.HnswGraph;
* <li><b>[int32]</b> vector similarity function ordinal
* <li><b>[vlong]</b> offset to this field's vectors in the .vec file
* <li><b>[vlong]</b> length of this field's vectors, in bytes
- * <li><b>[vlong]</b> offset to this field's index in the .vex file
- * <li><b>[vlong]</b> length of this field's index data, in bytes
+ * <li><b>[vlong]</b> offset to this field's graph index in the .vex file
+ * <li><b>[vlong]</b> length of this field's graph index data, in bytes
+ * <li><b>[vlong]</b> offset to this field's graph data in the .veg file
+ * <li><b>[vlong]</b> length of this field's graph data, in bytes
+ * <li><b>[int]</b> number of levels in the graph
* <li><b>[int]</b> dimension of this field's vectors
* <li><b>[int]</b> the number of documents having values for this field
* <li><b>array[vint]</b> the docids of documents having vectors, in order
- * <li><b>array[vlong]</b> for each document having a vector, the offset (delta-encoded relative
- * to the previous document) of its entry in the .vex file
* </ul>
*
* @lucene.experimental
@@ -69,10 +99,12 @@ public final class Lucene90HnswVectorsFormat extends KnnVectorsFormat {
static final String META_CODEC_NAME = "Lucene90HnswVectorsFormatMeta";
static final String VECTOR_DATA_CODEC_NAME = "Lucene90HnswVectorsFormatData";
- static final String VECTOR_INDEX_CODEC_NAME = "Lucene90HnswVectorsFormatIndex";
+ static final String GRAPH_INDEX_CODEC_NAME = "Lucene90HnswVectorsFormatGraphIndex";
+ static final String GRAPH_DATA_CODEC_NAME = "Lucene90HnswVectorsFormatGraphData";
static final String META_EXTENSION = "vem";
static final String VECTOR_DATA_EXTENSION = "vec";
- static final String VECTOR_INDEX_EXTENSION = "vex";
+ static final String GRAPH_INDEX_EXTENSION = "vex";
+ static final String GRAPH_DATA_EXTENSION = "veg";
static final int VERSION_START = 0;
static final int VERSION_CURRENT = VERSION_START;
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
index 38e18db..e2aced6 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
@@ -37,6 +37,7 @@ import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
+import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
@@ -60,12 +61,12 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
private final FieldInfos fieldInfos;
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IndexInput vectorData;
- private final IndexInput vectorIndex;
+ private final IndexInput graphIndex;
+ private final IndexInput graphData;
private final long checksumSeed;
Lucene90HnswVectorsReader(SegmentReadState state) throws IOException {
this.fieldInfos = state.fieldInfos;
-
int versionMeta = readMetadata(state, Lucene90HnswVectorsFormat.META_EXTENSION);
long[] checksumRef = new long[1];
boolean success = false;
@@ -77,13 +78,22 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
Lucene90HnswVectorsFormat.VECTOR_DATA_EXTENSION,
Lucene90HnswVectorsFormat.VECTOR_DATA_CODEC_NAME,
checksumRef);
- vectorIndex =
+ graphIndex =
openDataInput(
state,
versionMeta,
- Lucene90HnswVectorsFormat.VECTOR_INDEX_EXTENSION,
- Lucene90HnswVectorsFormat.VECTOR_INDEX_CODEC_NAME,
+ Lucene90HnswVectorsFormat.GRAPH_INDEX_EXTENSION,
+ Lucene90HnswVectorsFormat.GRAPH_INDEX_CODEC_NAME,
checksumRef);
+ graphData =
+ openDataInput(
+ state,
+ versionMeta,
+ Lucene90HnswVectorsFormat.GRAPH_DATA_EXTENSION,
+ Lucene90HnswVectorsFormat.GRAPH_DATA_CODEC_NAME,
+ checksumRef);
+ // TODO: should graph data be off-heap?
+ fillGraphNodesAndOffsetsByLevel();
success = true;
} finally {
if (success == false) {
@@ -204,6 +214,43 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
return new FieldEntry(input, similarityFunction);
}
+ private void fillGraphNodesAndOffsetsByLevel() throws IOException {
+ for (FieldEntry entry : fields.values()) {
+ IndexInput input =
+ graphIndex.slice("graph-index", entry.graphIndexOffset, entry.graphIndexLength);
+ int numLevels = input.readInt();
+ assert entry.numLevels == numLevels;
+ int[] numNodesByLevel = new int[numLevels];
+
+ // read nodes by level
+ for (int level = 0; level < numLevels; level++) {
+ numNodesByLevel[level] = input.readInt();
+ if (level == 0) {
+ // we don't store nodes for level 0th, as this level contains all nodes
+ entry.nodesByLevel[0] = null;
+ } else {
+ final int[] nodesOnLevel = new int[numNodesByLevel[level]];
+ for (int i = 0; i < numNodesByLevel[level]; i++) {
+ nodesOnLevel[i] = input.readVInt();
+ }
+ entry.nodesByLevel[level] = nodesOnLevel;
+ }
+ }
+
+ // read offsets by level
+ long offset = 0;
+ for (int level = 0; level < numLevels; level++) {
+ assert numNodesByLevel[level] > 0;
+ long[] ordOffsets = new long[numNodesByLevel[level]];
+ for (int i = 0; i < ordOffsets.length; i++) {
+ offset += input.readVLong();
+ ordOffsets[i] = offset;
+ }
+ entry.ordOffsetsByLevel[level] = ordOffsets;
+ }
+ }
+ }
+
@Override
public long ramBytesUsed() {
long totalBytes = RamUsageEstimator.shallowSizeOfInstance(Lucene90HnswVectorsReader.class);
@@ -219,7 +266,8 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
@Override
public void checkIntegrity() throws IOException {
CodecUtil.checksumEntireFile(vectorData);
- CodecUtil.checksumEntireFile(vectorIndex);
+ CodecUtil.checksumEntireFile(graphIndex);
+ CodecUtil.checksumEntireFile(graphData);
}
@Override
@@ -301,7 +349,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
throw new IllegalArgumentException("No such field '" + field + "'");
}
FieldEntry entry = fields.get(field);
- if (entry != null && entry.indexDataLength > 0) {
+ if (entry != null && entry.graphIndexLength > 0) {
return getGraphValues(entry);
} else {
return KnnGraphValues.EMPTY;
@@ -310,33 +358,39 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
private KnnGraphValues getGraphValues(FieldEntry entry) throws IOException {
IndexInput bytesSlice =
- vectorIndex.slice("graph-data", entry.indexDataOffset, entry.indexDataLength);
+ graphData.slice("graph-data", entry.graphDataOffset, entry.graphDataLength);
return new IndexedKnnGraphReader(entry, bytesSlice);
}
@Override
public void close() throws IOException {
- IOUtils.close(vectorData, vectorIndex);
+ IOUtils.close(vectorData, graphIndex, graphData);
}
private static class FieldEntry {
- final int dimension;
final VectorSimilarityFunction similarityFunction;
-
final long vectorDataOffset;
final long vectorDataLength;
- final long indexDataOffset;
- final long indexDataLength;
+ final long graphIndexOffset;
+ final long graphIndexLength;
+ final long graphDataOffset;
+ final long graphDataLength;
+ final int numLevels;
+ final int dimension;
final int[] ordToDoc;
- final long[] ordOffsets;
+ final int[][] nodesByLevel;
+ final long[][] ordOffsetsByLevel;
FieldEntry(DataInput input, VectorSimilarityFunction similarityFunction) throws IOException {
this.similarityFunction = similarityFunction;
vectorDataOffset = input.readVLong();
vectorDataLength = input.readVLong();
- indexDataOffset = input.readVLong();
- indexDataLength = input.readVLong();
+ graphIndexOffset = input.readVLong();
+ graphIndexLength = input.readVLong();
+ graphDataOffset = input.readVLong();
+ graphDataLength = input.readVLong();
+ numLevels = input.readInt();
dimension = input.readInt();
int size = input.readInt();
ordToDoc = new int[size];
@@ -344,12 +398,8 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
int doc = input.readVInt();
ordToDoc[i] = doc;
}
- ordOffsets = new long[size()];
- long offset = 0;
- for (int i = 0; i < ordOffsets.length; i++) {
- offset += input.readVLong();
- ordOffsets[i] = offset;
- }
+ nodesByLevel = new int[numLevels][];
+ ordOffsetsByLevel = new long[numLevels][];
}
int size() {
@@ -468,22 +518,39 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
/** Read the nearest-neighbors graph from the index input */
private static final class IndexedKnnGraphReader extends KnnGraphValues {
- final FieldEntry entry;
final IndexInput dataIn;
+ final int[][] nodesByLevel;
+ final long[][] ordOffsetsByLevel;
+ final int numLevels;
+ final int entryNode;
+ final int size;
int arcCount;
int arcUpTo;
int arc;
IndexedKnnGraphReader(FieldEntry entry, IndexInput dataIn) {
- this.entry = entry;
this.dataIn = dataIn;
+ this.nodesByLevel = entry.nodesByLevel;
+ this.ordOffsetsByLevel = entry.ordOffsetsByLevel;
+ this.numLevels = entry.numLevels;
+ this.entryNode = numLevels == 1 ? 0 : nodesByLevel[numLevels - 1][0];
+ this.size = entry.size();
}
@Override
public void seek(int level, int targetOrd) throws IOException {
+ long graphDataOffset;
+ if (level == 0) {
+ graphDataOffset = ordOffsetsByLevel[0][targetOrd];
+ } else {
+ int targetIndex =
+ Arrays.binarySearch(nodesByLevel[level], 0, nodesByLevel[level].length, targetOrd);
+ graphDataOffset = ordOffsetsByLevel[level][targetIndex];
+ }
+
// unsafe; no bounds checking
- dataIn.seek(entry.ordOffsets[targetOrd]);
+ dataIn.seek(graphDataOffset);
arcCount = dataIn.readInt();
arc = -1;
arcUpTo = 0;
@@ -491,7 +558,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
@Override
public int size() {
- return entry.size();
+ return size;
}
@Override
@@ -505,13 +572,78 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
}
@Override
- public int maxLevel() throws IOException {
- return 0;
+ public int numLevels() throws IOException {
+ return numLevels;
}
@Override
public int entryNode() throws IOException {
- return 0;
+ return entryNode;
+ }
+
+ @Override
+ public DocIdSetIterator getAllNodesOnLevel(int level) {
+ if (level == 0) {
+ return new DocIdSetIterator() {
+ int numNodes = size();
+ int idx = -1;
+
+ @Override
+ public int docID() {
+ return idx;
+ }
+
+ @Override
+ public int nextDoc() {
+ idx++;
+ if (idx >= numNodes) {
+ idx = NO_MORE_DOCS;
+ return NO_MORE_DOCS;
+ }
+ return idx;
+ }
+
+ @Override
+ public long cost() {
+ return numNodes;
+ }
+
+ @Override
+ public int advance(int target) {
+ throw new UnsupportedOperationException("Not supported");
+ }
+ };
+ } else {
+ return new DocIdSetIterator() {
+ final int[] nodes = nodesByLevel[level];
+ int idx = -1;
+
+ @Override
+ public int docID() {
+ return nodes[idx];
+ }
+
+ @Override
+ public int nextDoc() {
+ idx++;
+ if (idx >= nodes.length) {
+ idx = NO_MORE_DOCS;
+ return NO_MORE_DOCS;
+ }
+ return nodes[idx];
+ }
+
+ @Override
+ public long cost() {
+ return nodes.length;
+ }
+
+ @Override
+ public int advance(int target) {
+ throw new UnsupportedOperationException("Not supported");
+ }
+ };
+ }
}
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsWriter.java
index 40899b0..1bc55e2 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsWriter.java
@@ -29,6 +29,7 @@ import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
+import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
@@ -44,7 +45,7 @@ import org.apache.lucene.util.hnsw.NeighborArray;
public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
private final SegmentWriteState segmentWriteState;
- private final IndexOutput meta, vectorData, vectorIndex;
+ private final IndexOutput meta, vectorData, graphIndex, graphData;
private final int maxConn;
private final int beamWidth;
@@ -68,17 +69,24 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
state.segmentSuffix,
Lucene90HnswVectorsFormat.VECTOR_DATA_EXTENSION);
- String indexDataFileName =
+ String graphIndexFileName =
IndexFileNames.segmentFileName(
state.segmentInfo.name,
state.segmentSuffix,
- Lucene90HnswVectorsFormat.VECTOR_INDEX_EXTENSION);
+ Lucene90HnswVectorsFormat.GRAPH_INDEX_EXTENSION);
+
+ String graphDataFileName =
+ IndexFileNames.segmentFileName(
+ state.segmentInfo.name,
+ state.segmentSuffix,
+ Lucene90HnswVectorsFormat.GRAPH_DATA_EXTENSION);
boolean success = false;
try {
meta = state.directory.createOutput(metaFileName, state.context);
vectorData = state.directory.createOutput(vectorDataFileName, state.context);
- vectorIndex = state.directory.createOutput(indexDataFileName, state.context);
+ graphIndex = state.directory.createOutput(graphIndexFileName, state.context);
+ graphData = state.directory.createOutput(graphDataFileName, state.context);
CodecUtil.writeIndexHeader(
meta,
@@ -93,8 +101,14 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
state.segmentInfo.getId(),
state.segmentSuffix);
CodecUtil.writeIndexHeader(
- vectorIndex,
- Lucene90HnswVectorsFormat.VECTOR_INDEX_CODEC_NAME,
+ graphIndex,
+ Lucene90HnswVectorsFormat.GRAPH_INDEX_CODEC_NAME,
+ Lucene90HnswVectorsFormat.VERSION_CURRENT,
+ state.segmentInfo.getId(),
+ state.segmentSuffix);
+ CodecUtil.writeIndexHeader(
+ graphData,
+ Lucene90HnswVectorsFormat.GRAPH_DATA_CODEC_NAME,
Lucene90HnswVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
@@ -124,42 +138,44 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
writeVectorValue(vectors);
docIds[count] = docV;
}
- // count may be < vectors.size() e,g, if some documents were deleted
- long[] offsets = new long[count];
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
- long vectorIndexOffset = vectorIndex.getFilePointer();
+
+ int numLevels;
+ long graphIndexOffset = graphIndex.getFilePointer();
+ long graphDataOffset = graphData.getFilePointer();
if (vectors instanceof RandomAccessVectorValuesProducer) {
- writeGraph(
- vectorIndex,
- (RandomAccessVectorValuesProducer) vectors,
- fieldInfo.getVectorSimilarityFunction(),
- vectorIndexOffset,
- offsets,
- count,
- maxConn,
- beamWidth);
+ numLevels =
+ writeGraph(
+ (RandomAccessVectorValuesProducer) vectors, fieldInfo.getVectorSimilarityFunction());
} else {
throw new IllegalArgumentException(
"Indexing an HNSW graph requires a random access vector values, got " + vectors);
}
- long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
+ long graphIndexLength = graphIndex.getFilePointer() - graphIndexOffset;
+ long graphDataLength = graphData.getFilePointer() - graphDataOffset;
+
writeMeta(
fieldInfo,
vectorDataOffset,
vectorDataLength,
- vectorIndexOffset,
- vectorIndexLength,
+ graphIndexOffset,
+ graphIndexLength,
+ graphDataOffset,
+ graphDataLength,
+ numLevels,
count,
docIds);
- writeGraphOffsets(meta, offsets);
}
private void writeMeta(
FieldInfo field,
long vectorDataOffset,
long vectorDataLength,
- long indexDataOffset,
- long indexDataLength,
+ long graphIndexOffset,
+ long graphIndexLength,
+ long graphDataOffset,
+ long graphDataLength,
+ int numLevels,
int size,
int[] docIds)
throws IOException {
@@ -167,8 +183,11 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
meta.writeInt(field.getVectorSimilarityFunction().ordinal());
meta.writeVLong(vectorDataOffset);
meta.writeVLong(vectorDataLength);
- meta.writeVLong(indexDataOffset);
- meta.writeVLong(indexDataLength);
+ meta.writeVLong(graphIndexOffset);
+ meta.writeVLong(graphIndexLength);
+ meta.writeVLong(graphDataOffset);
+ meta.writeVLong(graphDataLength);
+ meta.writeInt(numLevels);
meta.writeInt(field.getVectorDimension());
meta.writeInt(size);
for (int i = 0; i < size; i++) {
@@ -184,52 +203,64 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
}
- private void writeGraphOffsets(IndexOutput out, long[] offsets) throws IOException {
- long last = 0;
- for (long offset : offsets) {
- out.writeVLong(offset - last);
- last = offset;
- }
- }
-
- private void writeGraph(
- IndexOutput graphData,
- RandomAccessVectorValuesProducer vectorValues,
- VectorSimilarityFunction similarityFunction,
- long graphDataOffset,
- long[] offsets,
- int count,
- int maxConn,
- int beamWidth)
+ private int writeGraph(
+ RandomAccessVectorValuesProducer vectorValues, VectorSimilarityFunction similarityFunction)
throws IOException {
+
+ // build graph
HnswGraphBuilder hnswGraphBuilder =
new HnswGraphBuilder(
- vectorValues, similarityFunction, maxConn, beamWidth, 0, HnswGraphBuilder.randSeed);
+ vectorValues, similarityFunction, maxConn, beamWidth, HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
- // TODO: implement storing of hierarchical graph; for now stores only 0th level
- for (int ord = 0; ord < count; ord++) {
- // write graph
- offsets[ord] = graphData.getFilePointer() - graphDataOffset;
-
- NeighborArray neighbors = graph.getNeighbors(0, ord);
- int size = neighbors.size();
+ graphIndex.writeInt(graph.numLevels()); // number of levels
+ for (int level = 0; level < graph.numLevels(); level++) {
+ DocIdSetIterator nodesOnLevel = graph.getAllNodesOnLevel(level);
+ int countOnLevel = (int) nodesOnLevel.cost();
+ // write graph nodes on the level into the graphIndex file
+ graphIndex.writeInt(countOnLevel); // number of nodes on a level
+ if (level > 0) {
+ for (int node = nodesOnLevel.nextDoc();
+ node != DocIdSetIterator.NO_MORE_DOCS;
+ node = nodesOnLevel.nextDoc()) {
+ graphIndex.writeVInt(node); // list of nodes on a level
+ }
+ }
+ }
- // Destructively modify; it's ok we are discarding it after this
- int[] nodes = neighbors.node();
- Arrays.sort(nodes, 0, size);
- graphData.writeInt(size);
+ long lastOffset = 0;
+ int countOnLevel0 = graph.size();
+ long graphDataOffset = graphData.getFilePointer();
+ for (int level = 0; level < graph.numLevels(); level++) {
+ DocIdSetIterator nodesOnLevel = graph.getAllNodesOnLevel(level);
+ for (int node = nodesOnLevel.nextDoc();
+ node != DocIdSetIterator.NO_MORE_DOCS;
+ node = nodesOnLevel.nextDoc()) {
+ // write graph offsets on the level into the graphIndex file
+ long offset = graphData.getFilePointer() - graphDataOffset;
+ graphIndex.writeVLong(offset - lastOffset);
+ lastOffset = offset;
- int lastNode = -1; // to make the assertion work?
- for (int i = 0; i < size; i++) {
- int node = nodes[i];
- assert node > lastNode : "nodes out of order: " + lastNode + "," + node;
- assert node < offsets.length : "node too large: " + node + ">=" + offsets.length;
- graphData.writeVInt(node - lastNode);
- lastNode = node;
+ // write neighbours on the level into the graphData file
+ NeighborArray neighbors = graph.getNeighbors(level, node);
+ int size = neighbors.size();
+ graphData.writeInt(size);
+ // Destructively modify; it's ok we are discarding it after this
+ int[] nnodes = neighbors.node();
+ Arrays.sort(nnodes, 0, size);
+ int lastNode = -1; // to make the assertion work?
+ for (int i = 0; i < size; i++) {
+ int nnode = nnodes[i];
+ assert nnode > lastNode : "nodes out of order: " + lastNode + "," + nnode;
+ assert nnode < countOnLevel0 : "node too large: " + nnode + ">=" + countOnLevel0;
+ graphData.writeVInt(nnode - lastNode);
+ lastNode = nnode;
+ }
}
}
+
+ return graph.numLevels();
}
@Override
@@ -246,12 +277,13 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
}
if (vectorData != null) {
CodecUtil.writeFooter(vectorData);
- CodecUtil.writeFooter(vectorIndex);
+ CodecUtil.writeFooter(graphIndex);
+ CodecUtil.writeFooter(graphData);
}
}
@Override
public void close() throws IOException {
- IOUtils.close(meta, vectorData, vectorIndex);
+ IOUtils.close(meta, vectorData, graphIndex, graphData);
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/index/KnnGraphValues.java b/lucene/core/src/java/org/apache/lucene/index/KnnGraphValues.java
index 11d1289..2ce9b4b 100644
--- a/lucene/core/src/java/org/apache/lucene/index/KnnGraphValues.java
+++ b/lucene/core/src/java/org/apache/lucene/index/KnnGraphValues.java
@@ -20,6 +20,7 @@ package org.apache.lucene.index;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
+import org.apache.lucene.search.DocIdSetIterator;
/**
* Access to per-document neighbor lists in a (hierarchical) knn search graph.
@@ -52,12 +53,21 @@ public abstract class KnnGraphValues {
*/
public abstract int nextNeighbor() throws IOException;
- /** Returns top level of the graph * */
- public abstract int maxLevel() throws IOException;
+ /** Returns the number of levels of the graph */
+ public abstract int numLevels() throws IOException;
/** Returns graph's entry point on the top level * */
public abstract int entryNode() throws IOException;
+ /**
+ * Get all nodes on a given level, expressed as level 0 node ordinals
+ *
+ * @param level level for which to get all nodes
+ * @return an iterator over nodes where {@code nextDoc} returns the next node ordinal
+ */
+ // TODO: return a more suitable iterator over nodes than DocIdSetIterator
+ public abstract DocIdSetIterator getAllNodesOnLevel(int level) throws IOException;
+
/** Empty graph value */
public static KnnGraphValues EMPTY =
new KnnGraphValues() {
@@ -76,7 +86,7 @@ public abstract class KnnGraphValues {
}
@Override
- public int maxLevel() {
+ public int numLevels() {
return 0;
}
@@ -84,5 +94,10 @@ public abstract class KnnGraphValues {
public int entryNode() {
return 0;
}
+
+ @Override
+ public DocIdSetIterator getAllNodesOnLevel(int level) {
+ return DocIdSetIterator.empty();
+ }
};
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
index 3598b9c..c3ffbb1 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
@@ -59,7 +59,7 @@ import org.apache.lucene.util.SparseFixedBitSet;
public final class HnswGraph extends KnnGraphValues {
private final int maxConn;
- private int curMaxLevel; // the current max graph level
+ private int numLevels; // the current number of levels in the graph
private int entryNode; // the current graph entry node on the top level
// Nodes by level expressed as the level 0's nodes' ordinals.
@@ -79,10 +79,10 @@ public final class HnswGraph extends KnnGraphValues {
HnswGraph(int maxConn, int levelOfFirstNode) {
this.maxConn = maxConn;
- this.curMaxLevel = levelOfFirstNode;
- this.graph = new ArrayList<>(curMaxLevel + 1);
+ this.numLevels = levelOfFirstNode + 1;
+ this.graph = new ArrayList<>(numLevels);
this.entryNode = 0;
- for (int i = 0; i <= curMaxLevel; i++) {
+ for (int i = 0; i < numLevels; i++) {
graph.add(new ArrayList<>());
// Typically with diversity criteria we see nodes not fully occupied;
// average fanout seems to be about 1/2 maxConn.
@@ -90,9 +90,9 @@ public final class HnswGraph extends KnnGraphValues {
graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4)));
}
- this.nodesByLevel = new ArrayList<>(curMaxLevel + 1);
- nodesByLevel.add(null); // we don't need this for 0th level, as it contians all nodes
- for (int l = 1; l <= curMaxLevel; l++) {
+ this.nodesByLevel = new ArrayList<>(numLevels);
+ nodesByLevel.add(null); // we don't need this for 0th level, as it contains all nodes
+ for (int l = 1; l < numLevels; l++) {
nodesByLevel.add(new int[] {0});
}
}
@@ -121,35 +121,25 @@ public final class HnswGraph extends KnnGraphValues {
Bits acceptOrds,
Random random)
throws IOException {
+
int size = graphValues.size();
- int boundedNumSeed = Math.min(numSeed, 2 * size);
+ int boundedNumSeed = Math.max(topK, Math.min(numSeed, 2 * size));
NeighborQueue results;
- if (graphValues.maxLevel() == 0) {
- // search in NSW; generate a number of entry points randomly
- final int[] eps = new int[boundedNumSeed];
- for (int i = 0; i < boundedNumSeed; i++) {
- eps[i] = random.nextInt(size);
- }
- return searchLevel(query, topK, 0, eps, vectors, similarityFunction, graphValues, acceptOrds);
- } else {
- // search in hierarchical NSW
- int[] eps = new int[] {graphValues.entryNode()};
- for (int level = graphValues.maxLevel(); level >= 1; level--) {
- results =
- HnswGraph.searchLevel(
- query, 1, level, eps, vectors, similarityFunction, graphValues, null);
- eps[0] = results.pop();
- }
- boundedNumSeed = Math.max(topK, boundedNumSeed);
+ int[] eps = new int[] {graphValues.entryNode()};
+ for (int level = graphValues.numLevels() - 1; level >= 1; level--) {
results =
HnswGraph.searchLevel(
- query, boundedNumSeed, 0, eps, vectors, similarityFunction, graphValues, acceptOrds);
- while (results.size() > topK) {
- results.pop();
- }
- return results;
+ query, 1, level, eps, vectors, similarityFunction, graphValues, null);
+ eps[0] = results.pop();
}
+ results =
+ HnswGraph.searchLevel(
+ query, boundedNumSeed, 0, eps, vectors, similarityFunction, graphValues, acceptOrds);
+ while (results.size() > topK) {
+ results.pop();
+ }
+ return results;
}
/**
@@ -266,12 +256,12 @@ public final class HnswGraph extends KnnGraphValues {
if (level > 0) {
// if the new node introduces a new level, add more levels to the graph,
// and make this node the graph's new entry point
- if (level > curMaxLevel) {
- for (int i = curMaxLevel + 1; i <= level; i++) {
+ if (level >= numLevels) {
+ for (int i = numLevels; i <= level; i++) {
graph.add(new ArrayList<>());
nodesByLevel.add(new int[] {node});
}
- curMaxLevel = level;
+ numLevels = level + 1;
entryNode = node;
} else {
// Add this node id to this level's nodes
@@ -305,13 +295,13 @@ public final class HnswGraph extends KnnGraphValues {
}
/**
- * Returns the current top level of the graph
+ * Returns the current number of levels in the graph
*
- * @return current maximum level of the graph
+ * @return the current number of levels in the graph
*/
@Override
- public int maxLevel() {
- return curMaxLevel;
+ public int numLevels() {
+ return numLevels;
}
/**
@@ -325,13 +315,7 @@ public final class HnswGraph extends KnnGraphValues {
return entryNode;
}
- /**
- * Get all nodes on a given level as node 0th ordinals
- *
- * @param level level for which to get all nodes
- * @return an iterator over nodes where {@code nextDoc} returns a next node
- */
- // TODO: return a more suitable iterator over nodes than DocIdSetIterator
+ @Override
public DocIdSetIterator getAllNodesOnLevel(int level) {
return new DocIdSetIterator() {
int[] nodes = level == 0 ? null : nodesByLevel.get(level);
@@ -354,12 +338,12 @@ public final class HnswGraph extends KnnGraphValues {
}
@Override
- public int advance(int target) {
- throw new UnsupportedOperationException("Not supported");
+ public long cost() {
+ return size;
}
@Override
- public long cost() {
+ public int advance(int target) {
throw new UnsupportedOperationException("Not supported");
}
};
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index f6beac0..f21c8f7 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -67,7 +67,6 @@ public final class HnswGraphBuilder {
* @param maxConn the number of connections to make when adding a new graph node; roughly speaking
* the graph fanout.
* @param beamWidth the size of the beam search to use when finding nearest neighbors.
- * @param ml normalization factor for level generation
* @param seed the seed for a random number generator used during graph construction. Provide this
* to ensure repeatable construction.
*/
@@ -76,7 +75,6 @@ public final class HnswGraphBuilder {
VectorSimilarityFunction similarityFunction,
int maxConn,
int beamWidth,
- double ml,
long seed) {
vectorValues = vectors.randomAccess();
buildVectors = vectors.randomAccess();
@@ -89,15 +87,11 @@ public final class HnswGraphBuilder {
}
this.maxConn = maxConn;
this.beamWidth = beamWidth;
- this.ml = ml;
+ // normalization factor for level generation; currently not configurable
+ this.ml = 1 / Math.log(1.0 * maxConn);
this.random = new Random(seed);
-
- if (ml == 0) {
- this.hnsw = new HnswGraph(maxConn, 0);
- } else {
- int levelOfFirstNode = getRandomGraphLevel(ml, random);
- this.hnsw = new HnswGraph(maxConn, levelOfFirstNode);
- }
+ int levelOfFirstNode = getRandomGraphLevel(ml, random);
+ this.hnsw = new HnswGraph(maxConn, levelOfFirstNode);
bound = BoundsChecker.create(similarityFunction.reversed);
scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1));
}
@@ -118,20 +112,6 @@ public final class HnswGraphBuilder {
if (infoStream.isEnabled(HNSW_COMPONENT)) {
infoStream.message(HNSW_COMPONENT, "build graph from " + vectors.size() + " vectors");
}
- if (ml == 0) {
- buildNSW(vectors);
- } else {
- buildHNSW(vectors);
- }
- return hnsw;
- }
-
- public void setInfoStream(InfoStream infoStream) {
- this.infoStream = infoStream;
- }
-
- // build navigable small world graph (single-layered)
- private void buildNSW(RandomAccessVectorValues vectors) throws IOException {
long start = System.nanoTime(), t = start;
// start at node 1! node 0 is added implicitly, in the constructor
for (int node = 1; node < vectors.size(); node++) {
@@ -140,41 +120,18 @@ public final class HnswGraphBuilder {
t = printGraphBuildStatus(node, start, t);
}
}
+ return hnsw;
}
- /** Inserts a doc with vector value to the graph */
- void addGraphNode(int node, float[] value) throws IOException {
- // We pass 'null' for acceptOrds because there are no deletions while building the graph
- NeighborQueue candidates =
- HnswGraph.search(
- value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, null, random);
-
- hnsw.addNode(0, node);
-
- /* connect neighbors to the new node, using a diversity heuristic that chooses successive
- * nearest neighbors that are closer to the new node than they are to the previously-selected
- * neighbors
- */
- addDiverseNeighbors(0, node, candidates);
- }
-
- // build hierarchical navigable small world graph (multi-layered)
- void buildHNSW(RandomAccessVectorValues vectors) throws IOException {
- long start = System.nanoTime(), t = start;
- // start at node 1! node 0 is added implicitly, in the constructor
- for (int node = 1; node < vectors.size(); node++) {
- addGraphNodeHNSW(node, vectors.vectorValue(node));
- if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) {
- t = printGraphBuildStatus(node, start, t);
- }
- }
+ public void setInfoStream(InfoStream infoStream) {
+ this.infoStream = infoStream;
}
/** Inserts a doc with vector value to the graph */
- void addGraphNodeHNSW(int node, float[] value) throws IOException {
+ void addGraphNode(int node, float[] value) throws IOException {
NeighborQueue candidates;
final int nodeLevel = getRandomGraphLevel(ml, random);
- int curMaxLevel = hnsw.maxLevel();
+ int curMaxLevel = hnsw.numLevels() - 1;
int[] eps = new int[] {hnsw.entryNode()};
// if a node introduces new levels to the graph, add this new node on new levels
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
index a67b6d1..a29e26b 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
@@ -16,6 +16,7 @@
*/
package org.apache.lucene.index;
+import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.hnsw.HnswGraphBuilder.randSeed;
@@ -26,6 +27,7 @@ import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
+import java.util.concurrent.CountDownLatch;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90Codec;
@@ -38,12 +40,18 @@ import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.document.StringField;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.SearcherFactory;
+import org.apache.lucene.search.SearcherManager;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
@@ -153,21 +161,56 @@ public class TestKnnGraph extends LuceneTestCase {
int dimension = atLeast(10);
float[][] values = randomVectors(numDoc, dimension);
int mergePoint = random().nextInt(numDoc);
- int[][] mergedGraph = getIndexedGraph(values, mergePoint, seed);
- int[][] singleSegmentGraph = getIndexedGraph(values, -1, seed);
+ int[][][] mergedGraph = getIndexedGraph(values, mergePoint, seed);
+ int[][][] singleSegmentGraph = getIndexedGraph(values, -1, seed);
assertGraphEquals(singleSegmentGraph, mergedGraph);
}
- private void assertGraphEquals(int[][] expected, int[][] actual) {
+ /** Test writing and reading of multiple vector fields * */
+ public void testMultipleVectorFields() throws Exception {
+ int numVectorFields = randomIntBetween(2, 5);
+ int numDoc = atLeast(100);
+ int[] dims = new int[numVectorFields];
+ float[][][] values = new float[numVectorFields][][];
+ for (int field = 0; field < numVectorFields; field++) {
+ dims[field] = atLeast(3);
+ values[field] = randomVectors(numDoc, dims[field]);
+ }
+
+ try (Directory dir = newDirectory();
+ IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null).setCodec(codec))) {
+ for (int docID = 0; docID < numDoc; docID++) {
+ Document doc = new Document();
+ for (int field = 0; field < numVectorFields; field++) {
+ float[] vector = values[field][docID];
+ if (vector != null) {
+ FieldType fieldType = KnnVectorField.createFieldType(vector.length, similarityFunction);
+ doc.add(new KnnVectorField(KNN_GRAPH_FIELD + field, vector, fieldType));
+ }
+ }
+ String idString = Integer.toString(docID);
+ doc.add(new StringField("id", idString, Field.Store.YES));
+ iw.addDocument(doc);
+ }
+ for (int field = 0; field < numVectorFields; field++) {
+ assertConsistentGraph(iw, values[field], KNN_GRAPH_FIELD + field);
+ }
+ }
+ }
+
+ private void assertGraphEquals(int[][][] expected, int[][][] actual) {
assertEquals("graph sizes differ", expected.length, actual.length);
- for (int i = 0; i < expected.length; i++) {
- assertArrayEquals("difference at ord=" + i, expected[i], actual[i]);
+ for (int level = 0; level < expected.length; level++) {
+ for (int node = 0; node < expected[level].length; node++) {
+ assertArrayEquals("difference at ord=" + node, expected[level][node], actual[level][node]);
+ }
}
}
- private int[][] getIndexedGraph(float[][] values, int mergePoint, long seed) throws IOException {
+ private int[][][] getIndexedGraph(float[][] values, int mergePoint, long seed)
+ throws IOException {
HnswGraphBuilder.randSeed = seed;
- int[][] graph;
+ int[][][] graph;
try (Directory dir = newDirectory()) {
IndexWriterConfig iwc = newIndexWriterConfig();
iwc.setMergePolicy(new LogDocMergePolicy()); // for predictable segment ordering when merging
@@ -208,18 +251,24 @@ public class TestKnnGraph extends LuceneTestCase {
return values;
}
- int[][] copyGraph(KnnGraphValues values) throws IOException {
- int size = values.size();
- int[][] graph = new int[size][];
+ int[][][] copyGraph(KnnGraphValues graphValues) throws IOException {
+ int[][][] graph = new int[graphValues.numLevels()][][];
+ int size = graphValues.size();
int[] scratch = new int[maxConn];
- for (int node = 0; node < size; node++) {
- int n, count = 0;
- values.seek(0, node);
- while ((n = values.nextNeighbor()) != NO_MORE_DOCS) {
- scratch[count++] = n;
- // graph[node][i++] = n;
+
+ for (int level = 0; level < graphValues.numLevels(); level++) {
+ DocIdSetIterator nodesItr = graphValues.getAllNodesOnLevel(level);
+ graph[level] = new int[size][];
+ for (int node = nodesItr.nextDoc();
+ node != DocIdSetIterator.NO_MORE_DOCS;
+ node = nodesItr.nextDoc()) {
+ graphValues.seek(level, node);
+ int n, count = 0;
+ while ((n = graphValues.nextNeighbor()) != NO_MORE_DOCS) {
+ scratch[count++] = n;
+ }
+ graph[level][node] = ArrayUtil.copyOfSubArray(scratch, 0, count);
}
- graph[node] = ArrayUtil.copyOfSubArray(scratch, 0, count);
}
return graph;
}
@@ -232,31 +281,7 @@ public class TestKnnGraph extends LuceneTestCase {
config.setCodec(codec); // test is not compatible with simpletext
try (Directory dir = newDirectory();
IndexWriter iw = new IndexWriter(dir, config)) {
- // Add a document for every cartesian point in an NxN square so we can
- // easily know which are the nearest neighbors to every point. Insert by iterating
- // using a prime number that is not a divisor of N*N so that we will hit each point once,
- // and chosen so that points will be inserted in a deterministic
- // but somewhat distributed pattern
- int n = 5, stepSize = 17;
- float[][] values = new float[n * n][];
- int index = 0;
- for (int i = 0; i < values.length; i++) {
- // System.out.printf("%d: (%d, %d)\n", i, index % n, index / n);
- int x = index % n, y = index / n;
- values[i] = new float[] {x, y};
- index = (index + stepSize) % (n * n);
- add(iw, i, values[i]);
- if (i == 13) {
- // create 2 segments
- iw.commit();
- }
- }
- boolean forceMerge = random().nextBoolean();
- // System.out.println("");
- if (forceMerge) {
- iw.forceMerge(1);
- }
- assertConsistentGraph(iw, values);
+ indexData(iw);
try (DirectoryReader dr = DirectoryReader.open(iw)) {
// results are ordered by score (descending) and docid (ascending);
// This is the insertion order:
@@ -279,6 +304,77 @@ public class TestKnnGraph extends LuceneTestCase {
}
}
+ private void indexData(IndexWriter iw) throws IOException {
+ // Add a document for every cartesian point in an NxN square so we can
+ // easily know which are the nearest neighbors to every point. Insert by iterating
+ // using a prime number that is not a divisor of N*N so that we will hit each point once,
+ // and chosen so that points will be inserted in a deterministic
+ // but somewhat distributed pattern
+ int n = 5, stepSize = 17;
+ float[][] values = new float[n * n][];
+ int index = 0;
+ for (int i = 0; i < values.length; i++) {
+ // System.out.printf("%d: (%d, %d)\n", i, index % n, index / n);
+ int x = index % n, y = index / n;
+ values[i] = new float[] {x, y};
+ index = (index + stepSize) % (n * n);
+ add(iw, i, values[i]);
+ if (i == 13) {
+ // create 2 segments
+ iw.commit();
+ }
+ }
+ boolean forceMerge = random().nextBoolean();
+ if (forceMerge) {
+ iw.forceMerge(1);
+ }
+ assertConsistentGraph(iw, values);
+ }
+
+ public void testMultiThreadedSearch() throws Exception {
+ similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
+ IndexWriterConfig config = newIndexWriterConfig();
+ config.setCodec(codec);
+ Directory dir = newDirectory();
+ IndexWriter iw = new IndexWriter(dir, config);
+ indexData(iw);
+
+ final SearcherManager manager = new SearcherManager(iw, new SearcherFactory());
+ Thread[] threads = new Thread[randomIntBetween(2, 5)];
+ final CountDownLatch latch = new CountDownLatch(1);
+ for (int i = 0; i < threads.length; i++) {
+ threads[i] =
+ new Thread(
+ () -> {
+ try {
+ latch.await();
+ IndexSearcher searcher = manager.acquire();
+ try {
+ KnnVectorQuery query = new KnnVectorQuery("vector", new float[] {0f, 0.1f}, 5);
+ TopDocs results = searcher.search(query, 5);
+ for (ScoreDoc doc : results.scoreDocs) {
+ // map docId to insertion id
+ doc.doc =
+ Integer.parseInt(searcher.getIndexReader().document(doc.doc).get("id"));
+ }
+ assertResults(new int[] {0, 15, 3, 18, 5}, results);
+ } finally {
+ manager.release(searcher);
+ }
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ });
+ threads[i].start();
+ }
+
+ latch.countDown();
+ for (Thread t : threads) {
+ t.join();
+ }
+ IOUtils.close(manager, iw, dir);
+ }
+
private void assertGraphSearch(int[] expected, float[] vector, IndexReader reader)
throws IOException {
TopDocs results = doKnnSearch(reader, vector, 5);
@@ -310,39 +406,40 @@ public class TestKnnGraph extends LuceneTestCase {
}
}
+ private void assertConsistentGraph(IndexWriter iw, float[][] values) throws IOException {
+ assertConsistentGraph(iw, values, KNN_GRAPH_FIELD);
+ }
+
// For each leaf, verify that its graph nodes are 1-1 with vectors, that the vectors are the
- // expected values,
- // and that the graph is fully connected and symmetric.
+ // expected values, and that the graph is fully connected and symmetric.
// NOTE: when we impose max-fanout on the graph it wil no longer be symmetric, but should still
// be fully connected. Is there any other invariant we can test? Well, we can check that max
- // fanout
- // is respected. We can test *desirable* properties of the graph like small-world (the graph
- // diameter
- // should be tightly bounded).
- private void assertConsistentGraph(IndexWriter iw, float[][] values) throws IOException {
- int totalGraphDocs = 0;
+ // fanout is respected. We can test *desirable* properties of the graph like small-world
+ // (the graph diameter should be tightly bounded).
+ private void assertConsistentGraph(IndexWriter iw, float[][] values, String vectorField)
+ throws IOException {
+ int numDocsWithVectors = 0;
try (DirectoryReader dr = DirectoryReader.open(iw)) {
for (LeafReaderContext ctx : dr.leaves()) {
LeafReader reader = ctx.reader();
- VectorValues vectorValues = reader.getVectorValues(KNN_GRAPH_FIELD);
+ VectorValues vectorValues = reader.getVectorValues(vectorField);
PerFieldKnnVectorsFormat.FieldsReader perFieldReader =
(PerFieldKnnVectorsFormat.FieldsReader) ((CodecReader) reader).getVectorReader();
if (perFieldReader == null) {
continue;
}
Lucene90HnswVectorsReader vectorReader =
- (Lucene90HnswVectorsReader) perFieldReader.getFieldReader(KNN_GRAPH_FIELD);
- KnnGraphValues graphValues = vectorReader.getGraphValues(KNN_GRAPH_FIELD);
- assertEquals((vectorValues == null), (graphValues == null));
+ (Lucene90HnswVectorsReader) perFieldReader.getFieldReader(vectorField);
+ KnnGraphValues graphValues = vectorReader.getGraphValues(vectorField);
if (vectorValues == null) {
+ assert graphValues == null;
continue;
}
- int[][] graph = new int[reader.maxDoc()][];
- boolean foundOrphan = false;
- int graphSize = 0;
+
+ // assert vector values:
+ // stored vector values are the same as original
for (int i = 0; i < reader.maxDoc(); i++) {
int nextDocWithVectors = vectorValues.advance(i);
- // System.out.println("advanced to " + nextDocWithVectors);
while (i < nextDocWithVectors && i < reader.maxDoc()) {
int id = Integer.parseInt(reader.document(i).get("id"));
assertNull("document " + id + " has no vector, but was expected to", values[id]);
@@ -352,7 +449,6 @@ public class TestKnnGraph extends LuceneTestCase {
break;
}
int id = Integer.parseInt(reader.document(i).get("id"));
- graphValues.seek(0, graphSize);
// documents with KnnGraphValues have the expected vectors
float[] scratch = vectorValues.vectorValue();
assertArrayEquals(
@@ -360,51 +456,69 @@ public class TestKnnGraph extends LuceneTestCase {
values[id],
scratch,
0f);
- // We collect neighbors for analysis below
- List<Integer> friends = new ArrayList<>();
- int arc;
- while ((arc = graphValues.nextNeighbor()) != NO_MORE_DOCS) {
- friends.add(arc);
+ numDocsWithVectors++;
+ }
+ assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
+
+ // assert graph values:
+ // For each level of the graph assert that:
+ // 1. There are no orphan nodes without any friends
+ // 2. If orphans are found, than the level must contain only 0 or a single node
+ // 3. If the number of nodes on the level doesn't exceed maxConn, assert that the graph is
+ // fully connected, i.e. any node is reachable from any other node.
+ // 4. If the number of nodes on the level exceeds maxConn, assert that maxConn is respected.
+ for (int level = 0; level < graphValues.numLevels(); level++) {
+ int[][] graphOnLevel = new int[graphValues.size()][];
+ int countOnLevel = 0;
+ boolean foundOrphan = false;
+ DocIdSetIterator nodesItr = graphValues.getAllNodesOnLevel(level);
+ for (int node = nodesItr.nextDoc();
+ node != DocIdSetIterator.NO_MORE_DOCS;
+ node = nodesItr.nextDoc()) {
+ graphValues.seek(level, node);
+ int arc;
+ List<Integer> friends = new ArrayList<>();
+ while ((arc = graphValues.nextNeighbor()) != NO_MORE_DOCS) {
+ friends.add(arc);
+ }
+ if (friends.size() == 0) {
+ foundOrphan = true;
+ } else {
+ int[] friendsCopy = new int[friends.size()];
+ Arrays.setAll(friendsCopy, friends::get);
+ graphOnLevel[node] = friendsCopy;
+ }
+ countOnLevel++;
}
- if (friends.size() == 0) {
- // System.out.printf("knngraph @%d is singleton (advance returns %d)\n", i,
- // nextWithNeighbors);
- foundOrphan = true;
+ // System.out.println("Level[" + level + "] has [" + nodesCount + "] nodes.");
+ assertEquals(nodesItr.cost(), countOnLevel);
+ assertFalse("No nodes on level [" + level + "]", countOnLevel == 0);
+ if (countOnLevel == 1) {
+ assertTrue(
+ "Graph with 1 node has unexpected neighbors on level [" + level + "]", foundOrphan);
} else {
- // NOTE: these friends are dense ordinals, not docIds.
- int[] friendCopy = new int[friends.size()];
- for (int j = 0; j < friends.size(); j++) {
- friendCopy[j] = friends.get(j);
+ assertFalse(
+ "Graph has orphan nodes with no friends on level [" + level + "]", foundOrphan);
+ if (maxConn > countOnLevel) {
+ // assert that the graph is fully connected,
+ // i.e. any node can be reached from any other node
+ assertConnected(graphOnLevel);
+ } else {
+ // assert that max-connections was respected
+ assertMaxConn(graphOnLevel, maxConn);
}
- graph[graphSize] = friendCopy;
- // System.out.printf("knngraph @%d => %s\n", i, Arrays.toString(graph[i]));
}
- graphSize++;
}
- assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
- if (foundOrphan) {
- assertEquals("graph is not fully connected", 1, graphSize);
- } else {
- assertTrue(
- "Graph has " + graphSize + " nodes, but one of them has no neighbors", graphSize > 1);
- }
- if (maxConn > graphSize) {
- // assert that the graph in each leaf is connected
- assertConnected(graph);
- } else {
- // assert that max-connections was respected
- assertMaxConn(graph, maxConn);
- }
- totalGraphDocs += graphSize;
}
}
- int expectedCount = 0;
- for (float[] friends : values) {
- if (friends != null) {
- ++expectedCount;
+
+ int expectedNumDocsWithVectors = 0;
+ for (float[] value : values) {
+ if (value != null) {
+ ++expectedNumDocsWithVectors;
}
}
- assertEquals(expectedCount, totalGraphDocs);
+ assertEquals(expectedNumDocsWithVectors, numDocsWithVectors);
}
public static void assertMaxConn(int[][] graph, int maxConn) {
@@ -418,37 +532,36 @@ public class TestKnnGraph extends LuceneTestCase {
}
}
+ /** Assert that every node is reachable from some other node */
private static void assertConnected(int[][] graph) {
- // every node in the graph is reachable from every other node
+ List<Integer> nodes = new ArrayList<>();
Set<Integer> visited = new HashSet<>();
List<Integer> queue = new LinkedList<>();
- int count = 0;
- for (int[] entry : graph) {
- if (entry != null) {
- if (queue.isEmpty()) {
- queue.add(entry[0]); // start from any node
- // System.out.println("start at " + entry[0]);
- }
- ++count;
+ for (int i = 0; i < graph.length; i++) {
+ if (graph[i] != null) {
+ nodes.add(i);
}
}
+
+ // start from any node
+ int startIdx = random().nextInt(nodes.size());
+ queue.add(nodes.get(startIdx));
while (queue.isEmpty() == false) {
int i = queue.remove(0);
assertNotNull("expected neighbors of " + i, graph[i]);
visited.add(i);
for (int j : graph[i]) {
if (visited.contains(j) == false) {
- // System.out.println(" ... " + j);
queue.add(j);
}
}
}
- for (int i = 0; i < count; i++) {
- assertTrue("Attempted to walk entire graph but never visited " + i, visited.contains(i));
+ // assert that every node is reachable from some other node as it was visited
+ for (int node : nodes) {
+ assertTrue(
+ "Attempted to walk entire graph but never visited node [" + node + "]",
+ visited.contains(node));
}
- // we visited each node exactly once
- assertEquals(
- "Attempted to walk entire graph but only visited " + visited.size(), count, visited.size());
}
private void add(IndexWriter iw, int id, float[] vector) throws IOException {
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
index bff69a0..cc977c5 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
@@ -253,7 +253,7 @@ public class KnnGraphTester {
try (BinaryFileVectors vectors = new BinaryFileVectors(docsPath)) {
RandomAccessVectorValues values = vectors.randomAccess();
HnswGraphBuilder builder =
- new HnswGraphBuilder(vectors, SIMILARITY_FUNCTION, maxConn, beamWidth, 0, 0);
+ new HnswGraphBuilder(vectors, SIMILARITY_FUNCTION, maxConn, beamWidth, 0);
// start at node 1
for (int i = 1; i < numDocs; i++) {
builder.addGraphNode(i, values.vectorValue(i));
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHNSWGraph2.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHNSWGraph2.java
deleted file mode 100644
index 2605ca8..0000000
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHNSWGraph2.java
+++ /dev/null
@@ -1,161 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.lucene.util.hnsw;
-
-import static org.apache.lucene.index.TestKnnGraph.assertMaxConn;
-import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
-
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.HashSet;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Random;
-import java.util.Set;
-import org.apache.lucene.index.VectorSimilarityFunction;
-import org.apache.lucene.search.DocIdSetIterator;
-import org.apache.lucene.util.LuceneTestCase;
-import org.apache.lucene.util.VectorUtil;
-
-public class TestHNSWGraph2 extends LuceneTestCase {
-
- // Tests that graph is consistent.
- public void testGraphConsistent() throws IOException {
- int dim = random().nextInt(100) + 1;
- int nDoc = random().nextInt(100) + 1;
- MockVectorValues values = new MockVectorValues(createRandomVectors(nDoc, dim, random()));
- int beamWidth = random().nextInt(10) + 5;
- int maxConn = random().nextInt(10) + 5;
- double ml = 1 / Math.log(1.0 * maxConn);
- long seed = random().nextLong();
- VectorSimilarityFunction similarityFunction =
- VectorSimilarityFunction.values()[
- random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
- HnswGraphBuilder builder =
- new HnswGraphBuilder(values, similarityFunction, maxConn, beamWidth, ml, seed);
- HnswGraph hnsw = builder.build(values);
- assertConsistentGraph(hnsw, maxConn);
- }
-
- /**
- * For each level of the graph, test that
- *
- * <p>1. There are no orphan nodes without any friends
- *
- * <p>2. If orphans are found, than the level must contain only 0 or a single node
- *
- * <p>3. If the number of nodes on the level doesn't exceed maxConn, assert that the graph is
- * fully connected, i.e. any node is reachable from any other node.
- *
- * <p>4. If the number of nodes on the level exceeds maxConn, assert that maxConn is respected.
- *
- * <p>copy from TestKnnGraph::assertConsistentGraph with parts relevant only to in-memory graphs
- * TODO: remove when hierarchical graph is implemented on disk
- */
- private static void assertConsistentGraph(HnswGraph hnsw, int maxConn) throws IOException {
- for (int level = hnsw.maxLevel(); level >= 0; level--) {
- int[][] graph = new int[hnsw.size()][];
- int nodesCount = 0;
- boolean foundOrphan = false;
-
- DocIdSetIterator nodesItr = hnsw.getAllNodesOnLevel(level);
- for (int node = nodesItr.nextDoc();
- node != DocIdSetIterator.NO_MORE_DOCS;
- node = nodesItr.nextDoc()) {
- hnsw.seek(level, node);
- int arc;
- List<Integer> friends = new ArrayList<>();
- while ((arc = hnsw.nextNeighbor()) != NO_MORE_DOCS) {
- friends.add(arc);
- }
- if (friends.size() == 0) {
- foundOrphan = true;
- } else {
- int[] friendsCopy = new int[friends.size()];
- for (int f = 0; f < friends.size(); f++) {
- friendsCopy[f] = friends.get(f);
- }
- graph[node] = friendsCopy;
- }
- nodesCount++;
- }
- // System.out.println("Level[" + level + "] has [" + nodesCount + "] nodes.");
-
- assertFalse("No nodes on level [" + level + "]", nodesCount == 0);
- if (nodesCount == 1) {
- assertTrue(
- "Graph with 1 node has unexpected neighbors on level [" + level + "]", foundOrphan);
- } else {
- assertFalse("Graph has orphan nodes with no friends on level [" + level + "]", foundOrphan);
- if (maxConn > nodesCount) {
- // assert that the graph is fully connected,
- // i.e. any node can be reached from any other node
- assertConnected(graph);
- } else {
- // assert that max-connections was respected
- assertMaxConn(graph, maxConn);
- }
- }
- }
- }
-
- /** Assert that every node is reachable from some other node */
- private static void assertConnected(int[][] graph) {
- List<Integer> nodes = new ArrayList<>();
- Set<Integer> visited = new HashSet<>();
- List<Integer> queue = new LinkedList<>();
- for (int i = 0; i < graph.length; i++) {
- if (graph[i] != null) {
- nodes.add(i);
- }
- }
-
- // start from any node
- int startIdx = random().nextInt(nodes.size());
- queue.add(nodes.get(startIdx));
- while (queue.isEmpty() == false) {
- int i = queue.remove(0);
- assertNotNull("expected neighbors of " + i, graph[i]);
- visited.add(i);
- for (int j : graph[i]) {
- if (visited.contains(j) == false) {
- queue.add(j);
- }
- }
- }
- // assert that every node is reachable from some other node as it was visited
- for (int node : nodes) {
- assertTrue(
- "Attempted to walk entire graph but never visited node [" + node + "]",
- visited.contains(node));
- }
- }
-
- private static float[][] createRandomVectors(int size, int dim, Random random) {
- float[][] vectors = new float[size][];
- for (int offset = 0; offset < size; offset++) {
- float[] vec = new float[dim];
- for (int i = 0; i < dim; i++) {
- vec[i] = random.nextFloat();
- }
- VectorUtil.l2normalize(vec);
- vectors[offset] = vec;
- }
- return vectors;
- }
-}
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
index 1d69e6d..6d2f827 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
@@ -43,6 +43,7 @@ import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
+import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
@@ -68,7 +69,7 @@ public class TestHnswGraph extends LuceneTestCase {
VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
HnswGraphBuilder builder =
- new HnswGraphBuilder(vectors, similarityFunction, maxConn, beamWidth, 0, seed);
+ new HnswGraphBuilder(vectors, similarityFunction, maxConn, beamWidth, seed);
HnswGraph hnsw = builder.build(vectors);
// Recreate the graph while indexing with the same random seed and write it out
@@ -115,32 +116,50 @@ public class TestHnswGraph extends LuceneTestCase {
((CodecReader) ctx.reader()).getVectorReader())
.getFieldReader("field"))
.getGraphValues("field");
- assertGraphEqual(hnsw, graphValues, nVec);
+ assertGraphEqual(hnsw, graphValues);
}
}
}
}
+ private void assertGraphEqual(KnnGraphValues g, KnnGraphValues h) throws IOException {
+ assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels());
+ assertEquals("the number of nodes in the graphs are different!", g.size(), h.size());
+
+ // assert equal nodes on each level
+ for (int level = 0; level < g.numLevels(); level++) {
+ DocIdSetIterator nodesOnLevel = g.getAllNodesOnLevel(level);
+ DocIdSetIterator nodesOnLevel2 = h.getAllNodesOnLevel(level);
+ for (int node = nodesOnLevel.nextDoc(), node2 = nodesOnLevel2.nextDoc();
+ node != DocIdSetIterator.NO_MORE_DOCS && node2 != DocIdSetIterator.NO_MORE_DOCS;
+ node = nodesOnLevel.nextDoc(), node2 = nodesOnLevel2.nextDoc()) {
+ assertEquals("nodes in the graphs are different", node, node2);
+ }
+ }
+
+ // assert equal nodes' neighbours on each level
+ for (int level = 0; level < g.numLevels(); level++) {
+ DocIdSetIterator nodesOnLevel = g.getAllNodesOnLevel(level);
+ for (int node = nodesOnLevel.nextDoc();
+ node != DocIdSetIterator.NO_MORE_DOCS;
+ node = nodesOnLevel.nextDoc()) {
+ g.seek(level, node);
+ h.seek(level, node);
+ assertEquals("arcs differ for node " + node, getNeighborNodes(g), getNeighborNodes(h));
+ }
+ }
+ }
+
// Make sure we actually approximately find the closest k elements. Mostly this is about
// ensuring that we have all the distance functions, comparators, priority queues and so on
// oriented in the right directions
public void testAknnDiverse() throws IOException {
int maxConn = 10;
- // single level graph
- double ml = 0;
- doTestAknnDiverse(maxConn, ml);
-
- // multi level graph
- ml = 1 / Math.log(1.0 * maxConn);
- doTestAknnDiverse(maxConn, ml);
- }
-
- private void doTestAknnDiverse(int maxConn, double ml) throws IOException {
int nDoc = 100;
TestHnswGraph.CircularVectorValues vectors = new TestHnswGraph.CircularVectorValues(nDoc);
HnswGraphBuilder builder =
new HnswGraphBuilder(
- vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, ml, random().nextInt());
+ vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt());
HnswGraph hnsw = builder.build(vectors);
// run some searches
NeighborQueue nn =
@@ -176,33 +195,28 @@ public class TestHnswGraph extends LuceneTestCase {
CircularVectorValues vectors = new CircularVectorValues(nDoc);
// the first 10 docs must not be deleted to ensure the expected recall
Bits acceptOrds = createRandomAcceptOrds(10, vectors.size);
-
- double ml1 = 0; // single level graph
- double ml2 = 1 / Math.log(1.0 * maxConn); // multi level graph
- for (double ml : new double[] {ml1, ml2}) {
- HnswGraphBuilder builder =
- new HnswGraphBuilder(
- vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, ml, random().nextInt());
- HnswGraph hnsw = builder.build(vectors);
- NeighborQueue nn =
- HnswGraph.search(
- new float[] {1, 0},
- 10,
- 5,
- vectors.randomAccess(),
- VectorSimilarityFunction.DOT_PRODUCT,
- hnsw,
- acceptOrds,
- random());
- int sum = 0;
- for (int node : nn.nodes()) {
- assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
- sum += node;
- }
- // We expect to get approximately 100% recall;
- // the lowest docIds are closest to zero; sum(0,9) = 45
- assertTrue("sum(result docs)=" + sum, sum < 75);
+ HnswGraphBuilder builder =
+ new HnswGraphBuilder(
+ vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt());
+ HnswGraph hnsw = builder.build(vectors);
+ NeighborQueue nn =
+ HnswGraph.search(
+ new float[] {1, 0},
+ 10,
+ 5,
+ vectors.randomAccess(),
+ VectorSimilarityFunction.DOT_PRODUCT,
+ hnsw,
+ acceptOrds,
+ random());
+ int sum = 0;
+ for (int node : nn.nodes()) {
+ assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
+ sum += node;
}
+ // We expect to get approximately 100% recall;
+ // the lowest docIds are closest to zero; sum(0,9) = 45
+ assertTrue("sum(result docs)=" + sum, sum < 75);
}
public void testBoundsCheckerMax() {
@@ -230,7 +244,7 @@ public class TestHnswGraph extends LuceneTestCase {
}
public void testHnswGraphBuilderInvalid() {
- expectThrows(NullPointerException.class, () -> new HnswGraphBuilder(null, null, 0, 0, 0, 0));
+ expectThrows(NullPointerException.class, () -> new HnswGraphBuilder(null, null, 0, 0, 0));
expectThrows(
IllegalArgumentException.class,
() ->
@@ -239,7 +253,6 @@ public class TestHnswGraph extends LuceneTestCase {
VectorSimilarityFunction.EUCLIDEAN,
0,
10,
- 0,
0));
expectThrows(
IllegalArgumentException.class,
@@ -249,7 +262,6 @@ public class TestHnswGraph extends LuceneTestCase {
VectorSimilarityFunction.EUCLIDEAN,
10,
0,
- 0,
0));
}
@@ -268,109 +280,49 @@ public class TestHnswGraph extends LuceneTestCase {
// First add nodes until everybody gets a full neighbor list
HnswGraphBuilder builder =
new HnswGraphBuilder(
- vectors, VectorSimilarityFunction.DOT_PRODUCT, 2, 10, 0, random().nextInt());
+ vectors, VectorSimilarityFunction.DOT_PRODUCT, 2, 10, random().nextInt());
// node 0 is added by the builder constructor
// builder.addGraphNode(vectors.vectorValue(0));
builder.addGraphNode(1, vectors.vectorValue(1));
builder.addGraphNode(2, vectors.vectorValue(2));
// now every node has tried to attach every other node as a neighbor, but
// some were excluded based on diversity check.
- assertNeighbors(builder.hnsw, 0, 1, 2);
- assertNeighbors(builder.hnsw, 1, 0);
- assertNeighbors(builder.hnsw, 2, 0);
+ assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
+ assertLevel0Neighbors(builder.hnsw, 1, 0);
+ assertLevel0Neighbors(builder.hnsw, 2, 0);
builder.addGraphNode(3, vectors.vectorValue(3));
- assertNeighbors(builder.hnsw, 0, 1, 2);
+ assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
// we added 3 here
- assertNeighbors(builder.hnsw, 1, 0, 3);
- assertNeighbors(builder.hnsw, 2, 0);
- assertNeighbors(builder.hnsw, 3, 1);
+ assertLevel0Neighbors(builder.hnsw, 1, 0, 3);
+ assertLevel0Neighbors(builder.hnsw, 2, 0);
+ assertLevel0Neighbors(builder.hnsw, 3, 1);
// supplant an existing neighbor
builder.addGraphNode(4, vectors.vectorValue(4));
// 4 is the same distance from 0 that 2 is; we leave the existing node in place
- assertNeighbors(builder.hnsw, 0, 1, 2);
+ assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
// 4 is closer to 1 than either existing neighbor (0, 3). 3 fails diversity check with 4, so
// replace it
- assertNeighbors(builder.hnsw, 1, 0, 4);
- assertNeighbors(builder.hnsw, 2, 0);
+ assertLevel0Neighbors(builder.hnsw, 1, 0, 4);
+ assertLevel0Neighbors(builder.hnsw, 2, 0);
// 1 survives the diversity check
- assertNeighbors(builder.hnsw, 3, 1, 4);
- assertNeighbors(builder.hnsw, 4, 1, 3);
+ assertLevel0Neighbors(builder.hnsw, 3, 1, 4);
+ assertLevel0Neighbors(builder.hnsw, 4, 1, 3);
builder.addGraphNode(5, vectors.vectorValue(5));
- assertNeighbors(builder.hnsw, 0, 1, 2);
- assertNeighbors(builder.hnsw, 1, 0, 5);
- assertNeighbors(builder.hnsw, 2, 0);
- // even though 5 is closer, 3 is not a neighbor of 5, so no update to *its* neighbors occurs
- assertNeighbors(builder.hnsw, 3, 1, 4);
- assertNeighbors(builder.hnsw, 4, 3, 5);
- assertNeighbors(builder.hnsw, 5, 1, 4);
- }
-
- public void testDiversityHNSW() throws IOException {
- // Some carefully checked test cases with simple 2d vectors on the unit circle:
- MockVectorValues vectors =
- new MockVectorValues(
- new float[][] {
- unitVector2d(0.5),
- unitVector2d(0.75),
- unitVector2d(0.2),
- unitVector2d(0.9),
- unitVector2d(0.8),
- unitVector2d(0.77),
- });
- // First add nodes until everybody gets a full neighbor list
- int maxConn = 2;
- double ml = 1 / Math.log(1.0 * maxConn);
- HnswGraphBuilder builder =
- new HnswGraphBuilder(
- vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 10, ml, random().nextInt());
- // node 0 is added by the builder constructor
- builder.addGraphNodeHNSW(1, vectors.vectorValue(1));
- builder.addGraphNodeHNSW(2, vectors.vectorValue(2));
- // now every node has tried to attach every other node as a neighbor, but
- // some were excluded based on diversity check.
- assertNeighbors(builder.hnsw, 0, 1, 2);
- assertNeighbors(builder.hnsw, 1, 0);
- assertNeighbors(builder.hnsw, 2, 0);
-
- builder.addGraphNodeHNSW(3, vectors.vectorValue(3));
- assertNeighbors(builder.hnsw, 0, 1, 2);
- // we added 3 here
- assertNeighbors(builder.hnsw, 1, 0, 3);
- assertNeighbors(builder.hnsw, 2, 0);
- assertNeighbors(builder.hnsw, 3, 1);
-
- // supplant an existing neighbor
- builder.addGraphNodeHNSW(4, vectors.vectorValue(4));
- // 4 is the same distance from 0 that 2 is; we leave the existing node in place
- assertNeighbors(builder.hnsw, 0, 1, 2);
- // 4 is closer to 1 than either existing neighbor (0, 3). 3 fails diversity check with 4, so
- // replace it
- assertNeighbors(builder.hnsw, 1, 0, 4);
- assertNeighbors(builder.hnsw, 2, 0);
- // 1 survives the diversity check
- assertNeighbors(builder.hnsw, 3, 1, 4);
- assertNeighbors(builder.hnsw, 4, 1, 3);
-
- builder.addGraphNodeHNSW(5, vectors.vectorValue(5));
- assertNeighbors(builder.hnsw, 0, 1, 2);
- assertNeighbors(builder.hnsw, 1, 0, 5);
- assertNeighbors(builder.hnsw, 2, 0);
+ assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
+ assertLevel0Neighbors(builder.hnsw, 1, 0, 5);
+ assertLevel0Neighbors(builder.hnsw, 2, 0);
// even though 5 is closer, 3 is not a neighbor of 5, so no update to *its* neighbors occurs
- assertNeighbors(builder.hnsw, 3, 1, 4);
- assertNeighbors(builder.hnsw, 4, 3, 5);
- assertNeighbors(builder.hnsw, 5, 1, 4);
+ assertLevel0Neighbors(builder.hnsw, 3, 1, 4);
+ assertLevel0Neighbors(builder.hnsw, 4, 3, 5);
+ assertLevel0Neighbors(builder.hnsw, 5, 1, 4);
}
- private void assertNeighbors(HnswGraph graph, int node, int... expected) {
- assertLevelNeighbors(graph, 0, node, expected);
- }
-
- private void assertLevelNeighbors(HnswGraph graph, int level, int node, int... expected) {
+ private void assertLevel0Neighbors(HnswGraph graph, int node, int... expected) {
Arrays.sort(expected);
- NeighborArray nn = graph.getNeighbors(level, node);
+ NeighborArray nn = graph.getNeighbors(0, node);
int[] actual = ArrayUtil.copyOfSubArray(nn.node, 0, nn.size());
Arrays.sort(actual);
assertArrayEquals(
@@ -383,42 +335,37 @@ public class TestHnswGraph extends LuceneTestCase {
int size = atLeast(100);
int dim = atLeast(10);
int maxConn = 10;
- double ml1 = 0; // single level graph
- double ml2 = 1 / Math.log(1.0 * maxConn); // multi level graph
RandomVectorValues vectors = new RandomVectorValues(size, dim, random());
VectorSimilarityFunction similarityFunction =
VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
int topK = 5;
+ HnswGraphBuilder builder =
+ new HnswGraphBuilder(vectors, similarityFunction, maxConn, 30, random().nextLong());
+ HnswGraph hnsw = builder.build(vectors);
+ Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);
- for (double ml : new double[] {ml1, ml2}) {
- HnswGraphBuilder builder =
- new HnswGraphBuilder(vectors, similarityFunction, maxConn, 30, ml, random().nextLong());
- HnswGraph hnsw = builder.build(vectors);
- Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);
-
- int totalMatches = 0;
- for (int i = 0; i < 100; i++) {
- float[] query = randomVector(random(), dim);
- NeighborQueue actual =
- HnswGraph.search(
- query, topK, 100, vectors, similarityFunction, hnsw, acceptOrds, random());
- NeighborQueue expected = new NeighborQueue(topK, similarityFunction.reversed);
- for (int j = 0; j < size; j++) {
- if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
- expected.add(j, similarityFunction.compare(query, vectors.vectorValue(j)));
- if (expected.size() > topK) {
- expected.pop();
- }
+ int totalMatches = 0;
+ for (int i = 0; i < 100; i++) {
+ float[] query = randomVector(random(), dim);
+ NeighborQueue actual =
+ HnswGraph.search(
+ query, topK, 100, vectors, similarityFunction, hnsw, acceptOrds, random());
+ NeighborQueue expected = new NeighborQueue(topK, similarityFunction.reversed);
+ for (int j = 0; j < size; j++) {
+ if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
+ expected.add(j, similarityFunction.compare(query, vectors.vectorValue(j)));
+ if (expected.size() > topK) {
+ expected.pop();
}
}
- assertEquals(topK, actual.size());
- totalMatches += computeOverlap(actual.nodes(), expected.nodes());
}
- double overlap = totalMatches / (double) (100 * topK);
- System.out.println("overlap=" + overlap + " totalMatches=" + totalMatches);
- assertTrue("overlap=" + overlap, overlap > 0.9);
+ assertEquals(topK, actual.size());
+ totalMatches += computeOverlap(actual.nodes(), expected.nodes());
}
+ double overlap = totalMatches / (double) (100 * topK);
+ System.out.println("overlap=" + overlap + " totalMatches=" + totalMatches);
+ assertTrue("overlap=" + overlap, overlap > 0.9);
}
private int computeOverlap(int[] a, int[] b) {
@@ -522,14 +469,6 @@ public class TestHnswGraph extends LuceneTestCase {
return value;
}
- private void assertGraphEqual(KnnGraphValues g, KnnGraphValues h, int size) throws IOException {
- for (int node = 0; node < size; node++) {
- g.seek(0, node);
- h.seek(0, node);
- assertEquals("arcs differ for node " + node, getNeighborNodes(g), getNeighborNodes(h));
- }
- }
-
private Set<Integer> getNeighborNodes(KnnGraphValues g) throws IOException {
Set<Integer> neighbors = new HashSet<>();
for (int n = g.nextNeighbor(); n != NO_MORE_DOCS; n = g.nextNeighbor()) {