You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by ma...@apache.org on 2022/05/12 19:59:35 UTC
[lucene] branch branch_9x updated: LUCENE-10527 Use 2*maxConn for last layer in HNSW (#872) (#887)
This is an automated email from the ASF dual-hosted git repository.
mayya pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git
The following commit(s) were added to refs/heads/branch_9x by this push:
new 88c37ef26ba LUCENE-10527 Use 2*maxConn for last layer in HNSW (#872) (#887)
88c37ef26ba is described below
commit 88c37ef26bad4714c1129638490dd8a2e4e46483
Author: Mayya Sharipova <ma...@elastic.co>
AuthorDate: Thu May 12 15:59:28 2022 -0400
LUCENE-10527 Use 2*maxConn for last layer in HNSW (#872) (#887)
The original HNSW paper (https://arxiv.org/pdf/1603.09320.pdf) suggests
using a different maxConn for the upper layers than for the bottom one
(which contains the full neighborhood graph). Specifically, they
suggest using maxConn=M for the upper layers and maxConn=2*M for the bottom.
This patch follows this recommendation and uses maxConn=2*M for the
bottom layer (level 0).
---
lucene/CHANGES.txt | 2 +
.../lucene91/Lucene91HnswGraphBuilder.java} | 85 +++++++++++++--------
.../lucene91/Lucene91NeighborArray.java | 89 ++++++++++++++++++++++
.../lucene91/Lucene91OnHeapHnswGraph.java} | 20 ++---
.../lucene91/Lucene91HnswVectorsWriter.java | 23 +++---
.../codecs/lucene92/Lucene92HnswVectorsFormat.java | 6 +-
.../codecs/lucene92/Lucene92HnswVectorsReader.java | 18 +++--
.../codecs/lucene92/Lucene92HnswVectorsWriter.java | 14 ++--
.../apache/lucene/util/hnsw/HnswGraphBuilder.java | 31 ++++----
.../apache/lucene/util/hnsw/HnswGraphSearcher.java | 4 +-
.../org/apache/lucene/util/hnsw/NeighborQueue.java | 2 +-
.../apache/lucene/util/hnsw/OnHeapHnswGraph.java | 25 +++---
.../test/org/apache/lucene/index/TestKnnGraph.java | 23 +++---
.../org/apache/lucene/util/hnsw/TestHnswGraph.java | 29 +++----
14 files changed, 246 insertions(+), 125 deletions(-)
diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index d68507c9fe4..6fb283814b1 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -56,6 +56,8 @@ Improvements
* LUCENE-9848: Correctly sort HNSW graph neighbors when applying diversity criterion (Mayya
Sharipova, Michael Sokolov)
+* LUCENE-10527: Use 2*maxConn for the last layer in HNSW (Mayya Sharipova)
+
Optimizations
---------------------
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java
similarity index 77%
copy from lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
copy to lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java
index b7935497f14..002497d2d2a 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.lucene.util.hnsw;
+package org.apache.lucene.backward_codecs.lucene91;
import static java.lang.Math.log;
@@ -28,12 +28,16 @@ import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;
+import org.apache.lucene.util.hnsw.BoundsChecker;
+import org.apache.lucene.util.hnsw.HnswGraph;
+import org.apache.lucene.util.hnsw.HnswGraphSearcher;
+import org.apache.lucene.util.hnsw.NeighborQueue;
/**
* Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the
* hyperparameters.
*/
-public final class HnswGraphBuilder {
+public final class Lucene91HnswGraphBuilder {
/** Default random seed for level generation * */
private static final long DEFAULT_RAND_SEED = 42;
@@ -46,7 +50,7 @@ public final class HnswGraphBuilder {
private final int maxConn;
private final int beamWidth;
private final double ml;
- private final NeighborArray scratch;
+ private final Lucene91NeighborArray scratch;
private final VectorSimilarityFunction similarityFunction;
private final RandomAccessVectorValues vectorValues;
@@ -54,7 +58,7 @@ public final class HnswGraphBuilder {
private final BoundsChecker bound;
private final HnswGraphSearcher graphSearcher;
- final OnHeapHnswGraph hnsw;
+ final Lucene91OnHeapHnswGraph hnsw;
private InfoStream infoStream = InfoStream.getDefault();
@@ -74,7 +78,7 @@ public final class HnswGraphBuilder {
* @param seed the seed for a random number generator used during graph construction. Provide this
* to ensure repeatable construction.
*/
- public HnswGraphBuilder(
+ public Lucene91HnswGraphBuilder(
RandomAccessVectorValuesProducer vectors,
VectorSimilarityFunction similarityFunction,
int maxConn,
@@ -96,15 +100,14 @@ public final class HnswGraphBuilder {
this.ml = 1 / Math.log(1.0 * maxConn);
this.random = new SplittableRandom(seed);
int levelOfFirstNode = getRandomGraphLevel(ml, random);
- this.hnsw = new OnHeapHnswGraph(maxConn, levelOfFirstNode, similarityFunction.reversed);
+ this.hnsw = new Lucene91OnHeapHnswGraph(maxConn, levelOfFirstNode);
this.graphSearcher =
new HnswGraphSearcher(
similarityFunction,
new NeighborQueue(beamWidth, similarityFunction.reversed == false),
new FixedBitSet(vectorValues.size()));
bound = BoundsChecker.create(similarityFunction.reversed);
- // in scratch we store candidates in reverse order: worse candidates are first
- scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1), similarityFunction.reversed);
+ scratch = new Lucene91NeighborArray(Math.max(beamWidth, maxConn + 1));
}
/**
@@ -115,7 +118,7 @@ public final class HnswGraphBuilder {
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
* accessor for the vectors
*/
- public OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
+ public Lucene91OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
if (vectors == vectorValues) {
throw new IllegalArgumentException(
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
@@ -178,32 +181,36 @@ public final class HnswGraphBuilder {
return now;
}
+ /* TODO: we are not maintaining nodes in strict score order; the forward links
+ * are added in sorted order, but the reverse implicit ones are not. Diversity heuristic should
+ * work better if we keep the neighbor arrays sorted. Possibly we should switch back to a heap?
+ * But first we should just see if sorting makes a significant difference.
+ */
private void addDiverseNeighbors(int level, int node, NeighborQueue candidates)
throws IOException {
/* For each of the beamWidth nearest candidates (going from best to worst), select it only if it
* is closer to target than it is to any of the already-selected neighbors (ie selected in this method,
* since the node is new and has no prior neighbors).
*/
- NeighborArray neighbors = hnsw.getNeighbors(level, node);
+ Lucene91NeighborArray neighbors = hnsw.getNeighbors(level, node);
assert neighbors.size() == 0; // new node
popToScratch(candidates);
- selectAndLinkDiverse(neighbors, scratch);
+ selectDiverse(neighbors, scratch);
// Link the selected nodes to the new node, and the new node to the selected nodes (again
// applying diversity heuristic)
int size = neighbors.size();
for (int i = 0; i < size; i++) {
int nbr = neighbors.node[i];
- NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr);
- nbrNbr.insertSorted(node, neighbors.score[i]);
+ Lucene91NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr);
+ nbrNbr.add(node, neighbors.score[i]);
if (nbrNbr.size() > maxConn) {
- int indexToRemove = findWorstNonDiverse(nbrNbr);
- nbrNbr.removeIndex(indexToRemove);
+ diversityUpdate(nbrNbr);
}
}
}
- private void selectAndLinkDiverse(NeighborArray neighbors, NeighborArray candidates)
+ private void selectDiverse(Lucene91NeighborArray neighbors, Lucene91NeighborArray candidates)
throws IOException {
// Select the best maxConn neighbors of the new node, applying the diversity heuristic
for (int i = candidates.size() - 1; neighbors.size() < maxConn && i >= 0; i--) {
@@ -241,7 +248,7 @@ public final class HnswGraphBuilder {
private boolean diversityCheck(
float[] candidate,
float score,
- NeighborArray neighbors,
+ Lucene91NeighborArray neighbors,
RandomAccessVectorValues vectorValues)
throws IOException {
bound.set(score);
@@ -255,26 +262,44 @@ public final class HnswGraphBuilder {
return true;
}
- /**
- * Find first non-diverse neighbour among the list of neighbors starting from the most distant
- * neighbours
- */
- private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
- for (int i = neighbors.size() - 1; i > 0; i--) {
- int cNode = neighbors.node[i];
- float[] cVector = vectorValues.vectorValue(cNode);
+ private void diversityUpdate(Lucene91NeighborArray neighbors) throws IOException {
+ assert neighbors.size() == maxConn + 1;
+ int replacePoint = findNonDiverse(neighbors);
+ if (replacePoint == -1) {
+ // none found; check score against worst existing neighbor
+ bound.set(neighbors.score[0]);
+ if (bound.check(neighbors.score[maxConn])) {
+ // drop the new neighbor; it is not competitive and there were no diversity failures
+ neighbors.removeLast();
+ return;
+ } else {
+ replacePoint = 0;
+ }
+ }
+ neighbors.node[replacePoint] = neighbors.node[maxConn];
+ neighbors.score[replacePoint] = neighbors.score[maxConn];
+ neighbors.removeLast();
+ }
+
+ // scan neighbors looking for diversity violations
+ private int findNonDiverse(Lucene91NeighborArray neighbors) throws IOException {
+ for (int i = neighbors.size() - 1; i >= 0; i--) {
+ // check each neighbor against its better-scoring neighbors. If it fails diversity check with
+ // them, drop it
+ int nbrNode = neighbors.node[i];
bound.set(neighbors.score[i]);
- // check the candidate against its better-scoring neighbors
- for (int j = i - 1; j >= 0; j--) {
+ float[] nbrVector = vectorValues.vectorValue(nbrNode);
+ for (int j = maxConn; j > i; j--) {
float diversityCheck =
- similarityFunction.compare(cVector, buildVectors.vectorValue(neighbors.node[j]));
- // node i is too similar to node j given its score relative to the base node
+ similarityFunction.compare(nbrVector, buildVectors.vectorValue(neighbors.node[j]));
if (bound.check(diversityCheck) == false) {
+ // node j is too similar to node i given its score relative to the base node
+ // replace it with the new node, which is at [maxConn]
return i;
}
}
}
- return neighbors.size() - 1;
+ return -1;
}
private static int getRandomGraphLevel(double ml, SplittableRandom random) {
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91NeighborArray.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91NeighborArray.java
new file mode 100644
index 00000000000..fcb097162f1
--- /dev/null
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91NeighborArray.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.backward_codecs.lucene91;
+
+import org.apache.lucene.util.ArrayUtil;
+
+/**
+ * Lucene91NeighborArray encodes the neighbors of a node and their scores relative to that node as
+ * a pair of parallel growable arrays. Entries are kept in insertion order; no sorting is done.
+ *
+ * @lucene.internal
+ */
+public class Lucene91NeighborArray {
+
+ private int size;
+
+ float[] score;
+ int[] node;
+
+ /** Create a neighbor array with the given initial capacity; it grows as nodes are added */
+ public Lucene91NeighborArray(int maxSize) {
+ node = new int[maxSize];
+ score = new float[maxSize];
+ }
+
+ /** Append a node and its score, growing the backing arrays when they are nearly full */
+ public void add(int newNode, float newScore) {
+ if (size == node.length - 1) { // grows while one free slot still remains
+ node = ArrayUtil.grow(node, (size + 1) * 3 / 2);
+ score = ArrayUtil.growExact(score, node.length); // keep score capacity equal to node's
+ }
+ node[size] = newNode;
+ score[size] = newScore;
+ ++size;
+ }
+
+ /** Get the number of nodes currently held in the array */
+ public int size() {
+ return size;
+ }
+
+ /**
+ * Direct access to the internal list of node ids; provided for efficient writing of the graph
+ *
+ * @lucene.internal
+ */
+ public int[] node() {
+ return node;
+ }
+
+ /**
+ * Direct access to the internal array of scores, parallel to {@link #node()}
+ *
+ * @lucene.internal
+ */
+ public float[] score() {
+ return score;
+ }
+
+ /** Remove all nodes by resetting the size to zero; the backing arrays are retained */
+ public void clear() {
+ size = 0;
+ }
+
+ /** Remove the last node from the array; the caller must ensure the array is not empty */
+ public void removeLast() {
+ size--;
+ }
+
+ @Override
+ public String toString() {
+ return "NeighborArray[" + size + "]";
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91OnHeapHnswGraph.java
similarity index 89%
copy from lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
copy to lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91OnHeapHnswGraph.java
index 08cecd1f8ac..2d3ef582b47 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91OnHeapHnswGraph.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.lucene.util.hnsw;
+package org.apache.lucene.backward_codecs.lucene91;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
@@ -23,15 +23,16 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.hnsw.HnswGraph;
+import org.apache.lucene.util.hnsw.NeighborQueue;
/**
* An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to
* construct the HNSW graph before it's written to the index.
*/
-public final class OnHeapHnswGraph extends HnswGraph {
+public final class Lucene91OnHeapHnswGraph extends HnswGraph {
private final int maxConn;
- private final boolean similarityReversed;
private int numLevels; // the current number of levels in the graph
private int entryNode; // the current graph entry node on the top level
@@ -44,15 +45,14 @@ public final class OnHeapHnswGraph extends HnswGraph {
// Each entry in the list has the top maxConn neighbors of a node. The nodes correspond to vectors
// added to HnswBuilder, and the node values are the ordinals of those vectors.
// Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals.
- private final List<List<NeighborArray>> graph;
+ private final List<List<Lucene91NeighborArray>> graph;
// KnnGraphValues iterator members
private int upto;
- private NeighborArray cur;
+ private Lucene91NeighborArray cur;
- OnHeapHnswGraph(int maxConn, int levelOfFirstNode, boolean similarityReversed) {
+ Lucene91OnHeapHnswGraph(int maxConn, int levelOfFirstNode) {
this.maxConn = maxConn;
- this.similarityReversed = similarityReversed;
this.numLevels = levelOfFirstNode + 1;
this.graph = new ArrayList<>(numLevels);
this.entryNode = 0;
@@ -61,7 +61,7 @@ public final class OnHeapHnswGraph extends HnswGraph {
// Typically with diversity criteria we see nodes not fully occupied;
// average fanout seems to be about 1/2 maxConn.
// There is some indexing time penalty for under-allocating, but saves RAM
- graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4), similarityReversed == false));
+ graph.get(i).add(new Lucene91NeighborArray(Math.max(32, maxConn / 4)));
}
this.nodesByLevel = new ArrayList<>(numLevels);
@@ -77,7 +77,7 @@ public final class OnHeapHnswGraph extends HnswGraph {
* @param level level of the graph
* @param node the node whose neighbors are returned, represented as an ordinal on the level 0.
*/
- public NeighborArray getNeighbors(int level, int node) {
+ public Lucene91NeighborArray getNeighbors(int level, int node) {
if (level == 0) {
return graph.get(level).get(node);
}
@@ -122,7 +122,7 @@ public final class OnHeapHnswGraph extends HnswGraph {
}
}
- graph.get(level).add(new NeighborArray(maxConn + 1, similarityReversed == false));
+ graph.get(level).add(new Lucene91NeighborArray(maxConn + 1));
}
@Override
diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java
index 6e1527541b5..0542163057e 100644
--- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java
+++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java
@@ -37,9 +37,6 @@ import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
-import org.apache.lucene.util.hnsw.HnswGraphBuilder;
-import org.apache.lucene.util.hnsw.NeighborArray;
-import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
/**
* Writes vector values and knn graphs to index segments.
@@ -145,7 +142,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
Lucene91HnswVectorsReader.OffHeapVectorValues offHeapVectors =
new Lucene91HnswVectorsReader.OffHeapVectorValues(
vectors.dimension(), docsWithField.cardinality(), null, vectorDataInput);
- OnHeapHnswGraph graph =
+ Lucene91OnHeapHnswGraph graph =
offHeapVectors.size() == 0
? null
: writeGraph(offHeapVectors, fieldInfo.getVectorSimilarityFunction());
@@ -194,7 +191,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
long vectorIndexOffset,
long vectorIndexLength,
DocsWithFieldSet docsWithField,
- OnHeapHnswGraph graph)
+ Lucene91OnHeapHnswGraph graph)
throws IOException {
meta.writeInt(field.number);
meta.writeInt(field.getVectorSimilarityFunction().ordinal());
@@ -236,16 +233,20 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
}
}
- private OnHeapHnswGraph writeGraph(
+ private Lucene91OnHeapHnswGraph writeGraph(
RandomAccessVectorValuesProducer vectorValues, VectorSimilarityFunction similarityFunction)
throws IOException {
// build graph
- HnswGraphBuilder hnswGraphBuilder =
- new HnswGraphBuilder(
- vectorValues, similarityFunction, maxConn, beamWidth, HnswGraphBuilder.randSeed);
+ Lucene91HnswGraphBuilder hnswGraphBuilder =
+ new Lucene91HnswGraphBuilder(
+ vectorValues,
+ similarityFunction,
+ maxConn,
+ beamWidth,
+ Lucene91HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
- OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
+ Lucene91OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
// write vectors' neighbours on each level into the vectorIndex file
int countOnLevel0 = graph.size();
@@ -253,7 +254,7 @@ public final class Lucene91HnswVectorsWriter extends KnnVectorsWriter {
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
while (nodesOnLevel.hasNext()) {
int node = nodesOnLevel.nextInt();
- NeighborArray neighbors = graph.getNeighbors(level, node);
+ Lucene91NeighborArray neighbors = graph.getNeighbors(level, node);
int size = neighbors.size();
vectorIndex.writeInt(size);
// Destructively modify; it's ok we are discarding it after this
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsFormat.java
index d762033e96d..3b28b706890 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsFormat.java
@@ -58,9 +58,9 @@ import org.apache.lucene.util.hnsw.HnswGraph;
* <ul>
* <li><b>[int32]</b> the number of neighbor nodes
* <li><b>array[int32]</b> the neighbor ordinals
- * <li><b>array[int32]</b> padding from empty integers if the number of neighbors less
- * than the maximum number of connections (maxConn). Padding is equal to
- * ((maxConn-the number of neighbours) * 4) bytes.
+ * <li><b>array[int32]</b> padding if the node has fewer neighbors than the maximum
+ * number of connections allowed on its level (maxConnOnLevel: M * 2 on level 0, M on
+ * upper levels). Padding is equal to ((maxConnOnLevel - number of neighbors) * 4) bytes.
* </ul>
* </ul>
* </ul>
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsReader.java
index a03b0b9b32b..1b1a120a6b2 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsReader.java
@@ -282,7 +282,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
final long vectorDataLength;
final long vectorIndexOffset;
final long vectorIndexLength;
- final int maxConn;
+ final int M;
final int numLevels;
final int dimension;
final int size;
@@ -336,7 +336,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
}
// read nodes by level
- maxConn = input.readInt();
+ M = input.readInt();
numLevels = input.readInt();
nodesByLevel = new int[numLevels][];
for (int level = 0; level < numLevels; level++) {
@@ -359,10 +359,13 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
for (int level = 0; level < numLevels; level++) {
if (level == 0) {
graphOffsetsByLevel[level] = 0;
+ } else if (level == 1) {
+ int numNodesOnLevel0 = size;
+ graphOffsetsByLevel[level] = (1 + (M * 2)) * Integer.BYTES * numNodesOnLevel0;
} else {
- int numNodesOnPrevLevel = level == 1 ? size : nodesByLevel[level - 1].length;
+ int numNodesOnPrevLevel = nodesByLevel[level - 1].length;
graphOffsetsByLevel[level] =
- graphOffsetsByLevel[level - 1] + (1 + maxConn) * Integer.BYTES * numNodesOnPrevLevel;
+ graphOffsetsByLevel[level - 1] + (1 + M) * Integer.BYTES * numNodesOnPrevLevel;
}
}
}
@@ -382,6 +385,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
final int entryNode;
final int size;
final long bytesForConns;
+ final long bytesForConns0;
int arcCount;
int arcUpTo;
@@ -394,7 +398,8 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
this.entryNode = numLevels > 1 ? nodesByLevel[numLevels - 1][0] : 0;
this.size = entry.size();
this.graphOffsetsByLevel = entry.graphOffsetsByLevel;
- this.bytesForConns = ((long) entry.maxConn + 1) * Integer.BYTES;
+ this.bytesForConns = ((long) entry.M + 1) * Integer.BYTES;
+ this.bytesForConns0 = ((long) (entry.M * 2) + 1) * Integer.BYTES;
}
@Override
@@ -404,7 +409,8 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
? targetOrd
: Arrays.binarySearch(nodesByLevel[level], 0, nodesByLevel[level].length, targetOrd);
assert targetIndex >= 0;
- long graphDataOffset = graphOffsetsByLevel[level] + targetIndex * bytesForConns;
+ long graphDataOffset =
+ graphOffsetsByLevel[level] + targetIndex * (level == 0 ? bytesForConns0 : bytesForConns);
// unsafe; no bounds checking
dataIn.seek(graphDataOffset);
arcCount = dataIn.readInt();
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsWriter.java
index c2fc9b35975..e63ab2ce277 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene92/Lucene92HnswVectorsWriter.java
@@ -55,13 +55,12 @@ public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter {
private final IndexOutput meta, vectorData, vectorIndex;
private final int maxDoc;
- private final int maxConn;
+ private final int M;
private final int beamWidth;
private boolean finished;
- Lucene92HnswVectorsWriter(SegmentWriteState state, int maxConn, int beamWidth)
- throws IOException {
- this.maxConn = maxConn;
+ Lucene92HnswVectorsWriter(SegmentWriteState state, int M, int beamWidth) throws IOException {
+ this.M = M;
this.beamWidth = beamWidth;
assert state.fieldInfos.hasVectorValues();
@@ -248,7 +247,7 @@ public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter {
meta.writeLong(vectorData.getFilePointer() - start);
}
- meta.writeInt(maxConn);
+ meta.writeInt(M);
// write graph nodes on each level
if (graph == null) {
meta.writeInt(0);
@@ -274,13 +273,14 @@ public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter {
// build graph
HnswGraphBuilder hnswGraphBuilder =
new HnswGraphBuilder(
- vectorValues, similarityFunction, maxConn, beamWidth, HnswGraphBuilder.randSeed);
+ vectorValues, similarityFunction, M, beamWidth, HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
// write vectors' neighbours on each level into the vectorIndex file
int countOnLevel0 = graph.size();
for (int level = 0; level < graph.numLevels(); level++) {
+ int maxConnOnLevel = level == 0 ? (M * 2) : M;
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
while (nodesOnLevel.hasNext()) {
int node = nodesOnLevel.nextInt();
@@ -297,7 +297,7 @@ public final class Lucene92HnswVectorsWriter extends KnnVectorsWriter {
}
// if number of connections < maxConn, add bogus values up to maxConn to have predictable
// offsets
- for (int i = size; i < maxConn; i++) {
+ for (int i = size; i < maxConnOnLevel; i++) {
vectorIndex.writeInt(0);
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index b7935497f14..b611d082c96 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -43,7 +43,7 @@ public final class HnswGraphBuilder {
/** Random seed for level generation; public to expose for testing * */
public static long randSeed = DEFAULT_RAND_SEED;
- private final int maxConn;
+ private final int M; // max number of connections on upper layers
private final int beamWidth;
private final double ml;
private final NeighborArray scratch;
@@ -68,8 +68,8 @@ public final class HnswGraphBuilder {
*
* @param vectors the vectors whose relations are represented by the graph - must provide a
* different view over those vectors than the one used to add via addGraphNode.
- * @param maxConn the number of connections to make when adding a new graph node; roughly speaking
- * the graph fanout.
+ * @param M graph fanout parameter used to calculate the maximum number of connections a node
+ * can have: M on upper layers, and M * 2 on the lowest level (level 0).
* @param beamWidth the size of the beam search to use when finding nearest neighbors.
* @param seed the seed for a random number generator used during graph construction. Provide this
* to ensure repeatable construction.
@@ -77,26 +77,26 @@ public final class HnswGraphBuilder {
public HnswGraphBuilder(
RandomAccessVectorValuesProducer vectors,
VectorSimilarityFunction similarityFunction,
- int maxConn,
+ int M,
int beamWidth,
long seed)
throws IOException {
vectorValues = vectors.randomAccess();
buildVectors = vectors.randomAccess();
this.similarityFunction = Objects.requireNonNull(similarityFunction);
- if (maxConn <= 0) {
+ if (M <= 0) {
throw new IllegalArgumentException("maxConn must be positive");
}
if (beamWidth <= 0) {
throw new IllegalArgumentException("beamWidth must be positive");
}
- this.maxConn = maxConn;
+ this.M = M;
this.beamWidth = beamWidth;
// normalization factor for level generation; currently not configurable
- this.ml = 1 / Math.log(1.0 * maxConn);
+ this.ml = 1 / Math.log(1.0 * M);
this.random = new SplittableRandom(seed);
int levelOfFirstNode = getRandomGraphLevel(ml, random);
- this.hnsw = new OnHeapHnswGraph(maxConn, levelOfFirstNode, similarityFunction.reversed);
+ this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode, similarityFunction.reversed);
this.graphSearcher =
new HnswGraphSearcher(
similarityFunction,
@@ -104,7 +104,7 @@ public final class HnswGraphBuilder {
new FixedBitSet(vectorValues.size()));
bound = BoundsChecker.create(similarityFunction.reversed);
// in scratch we store candidates in reverse order: worse candidates are first
- scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1), similarityFunction.reversed);
+ scratch = new NeighborArray(Math.max(beamWidth, M + 1), similarityFunction.reversed);
}
/**
@@ -187,7 +187,8 @@ public final class HnswGraphBuilder {
NeighborArray neighbors = hnsw.getNeighbors(level, node);
assert neighbors.size() == 0; // new node
popToScratch(candidates);
- selectAndLinkDiverse(neighbors, scratch);
+ int maxConnOnLevel = level == 0 ? M * 2 : M;
+ selectAndLinkDiverse(neighbors, scratch, maxConnOnLevel);
// Link the selected nodes to the new node, and the new node to the selected nodes (again
// applying diversity heuristic)
@@ -196,17 +197,17 @@ public final class HnswGraphBuilder {
int nbr = neighbors.node[i];
NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr);
nbrNbr.insertSorted(node, neighbors.score[i]);
- if (nbrNbr.size() > maxConn) {
+ if (nbrNbr.size() > maxConnOnLevel) {
int indexToRemove = findWorstNonDiverse(nbrNbr);
nbrNbr.removeIndex(indexToRemove);
}
}
}
- private void selectAndLinkDiverse(NeighborArray neighbors, NeighborArray candidates)
- throws IOException {
- // Select the best maxConn neighbors of the new node, applying the diversity heuristic
- for (int i = candidates.size() - 1; neighbors.size() < maxConn && i >= 0; i--) {
+ private void selectAndLinkDiverse(
+ NeighborArray neighbors, NeighborArray candidates, int maxConnOnLevel) throws IOException {
+ // Select the best maxConnOnLevel neighbors of the new node, applying the diversity heuristic
+ for (int i = candidates.size() - 1; neighbors.size() < maxConnOnLevel && i >= 0; i--) {
// compare each neighbor (in distance order) against the closer neighbors selected so far,
// only adding it if it is closer to the target than to any of the other selected neighbors
int cNode = candidates.node[i];
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
index e02fbcd9bda..b1a2436166f 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
@@ -47,7 +47,7 @@ public final class HnswGraphSearcher {
* @param candidates max heap that will track the candidate nodes to explore
* @param visited bit set that will track nodes that have already been visited
*/
- HnswGraphSearcher(
+ public HnswGraphSearcher(
VectorSimilarityFunction similarityFunction, NeighborQueue candidates, BitSet visited) {
this.similarityFunction = similarityFunction;
this.candidates = candidates;
@@ -112,7 +112,7 @@ public final class HnswGraphSearcher {
* @param graph the graph values
* @return a priority queue holding the closest neighbors found
*/
- NeighborQueue searchLevel(
+ public NeighborQueue searchLevel(
float[] query,
int topK,
int level,
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java
index cb58c608f61..a2c7253261b 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java
@@ -98,7 +98,7 @@ public class NeighborQueue {
return (int) order.apply(heap.pop());
}
- int[] nodes() {
+ public int[] nodes() {
int size = size();
int[] nodes = new int[size];
for (int i = 0; i < size; i++) {
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
index 08cecd1f8ac..1dc0845ccd5 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
@@ -30,7 +30,6 @@ import org.apache.lucene.util.ArrayUtil;
*/
public final class OnHeapHnswGraph extends HnswGraph {
- private final int maxConn;
private final boolean similarityReversed;
private int numLevels; // the current number of levels in the graph
private int entryNode; // the current graph entry node on the top level
@@ -41,27 +40,30 @@ public final class OnHeapHnswGraph extends HnswGraph {
// graph is a list of graph levels.
// Each level is represented as List<NeighborArray> – nodes' connections on this level.
- // Each entry in the list has the top maxConn neighbors of a node. The nodes correspond to vectors
+ // Each entry in the list has the top maxConn/maxConn0 neighbors of a node. The nodes correspond
+ // to vectors
// added to HnswBuilder, and the node values are the ordinals of those vectors.
// Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals.
private final List<List<NeighborArray>> graph;
+ private final int nsize;
+ private final int nsize0;
// KnnGraphValues iterator members
private int upto;
private NeighborArray cur;
- OnHeapHnswGraph(int maxConn, int levelOfFirstNode, boolean similarityReversed) {
- this.maxConn = maxConn;
+ OnHeapHnswGraph(int M, int levelOfFirstNode, boolean similarityReversed) {
this.similarityReversed = similarityReversed;
this.numLevels = levelOfFirstNode + 1;
this.graph = new ArrayList<>(numLevels);
this.entryNode = 0;
- for (int i = 0; i < numLevels; i++) {
+ // Neighbours' size on upper levels (nsize) and level 0 (nsize0)
+ // We allocate extra space for neighbours, but then prune them to keep allowed maximum
+ this.nsize = M + 1;
+ this.nsize0 = (M * 2 + 1);
+ for (int l = 0; l < numLevels; l++) {
graph.add(new ArrayList<>());
- // Typically with diversity criteria we see nodes not fully occupied;
- // average fanout seems to be about 1/2 maxConn.
- // There is some indexing time penalty for under-allocating, but saves RAM
- graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4), similarityReversed == false));
+ graph.get(l).add(new NeighborArray(l == 0 ? nsize0 : nsize, similarityReversed == false));
}
this.nodesByLevel = new ArrayList<>(numLevels);
@@ -121,8 +123,9 @@ public final class OnHeapHnswGraph extends HnswGraph {
}
}
}
-
- graph.get(level).add(new NeighborArray(maxConn + 1, similarityReversed == false));
+ graph
+ .get(level)
+ .add(new NeighborArray(level == 0 ? nsize0 : nsize, similarityReversed == false));
}
@Override
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
index b2bf9cab30c..3a54c407070 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
@@ -64,7 +64,7 @@ public class TestKnnGraph extends LuceneTestCase {
private static final String KNN_GRAPH_FIELD = "vector";
- private static int maxConn = Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN;
+ private static int M = Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN;
private Codec codec;
private VectorSimilarityFunction similarityFunction;
@@ -73,15 +73,14 @@ public class TestKnnGraph extends LuceneTestCase {
public void setup() {
randSeed = random().nextLong();
if (random().nextBoolean()) {
- maxConn = random().nextInt(256) + 3;
+ M = random().nextInt(256) + 3;
}
codec =
new Lucene92Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
- return new Lucene92HnswVectorsFormat(
- maxConn, Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
+ return new Lucene92HnswVectorsFormat(M, Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
}
};
@@ -91,7 +90,7 @@ public class TestKnnGraph extends LuceneTestCase {
@After
public void cleanup() {
- maxConn = Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN;
+ M = Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN;
}
/** Basic test of creating documents in a graph */
@@ -263,7 +262,7 @@ public class TestKnnGraph extends LuceneTestCase {
int[][][] copyGraph(HnswGraph graphValues) throws IOException {
int[][][] graph = new int[graphValues.numLevels()][][];
int size = graphValues.size();
- int[] scratch = new int[maxConn];
+ int[] scratch = new int[M * 2];
for (int level = 0; level < graphValues.numLevels(); level++) {
NodesIterator nodesItr = graphValues.getNodesOnLevel(level);
@@ -483,10 +482,13 @@ public class TestKnnGraph extends LuceneTestCase {
// For each level of the graph assert that:
// 1. There are no orphan nodes without any friends
// 2. If orphans are found, than the level must contain only 0 or a single node
- // 3. If the number of nodes on the level doesn't exceed maxConn, assert that the graph is
+ // 3. If the number of nodes on the level doesn't exceed maxConnOnLevel, assert that the
+ // graph is
// fully connected, i.e. any node is reachable from any other node.
- // 4. If the number of nodes on the level exceeds maxConn, assert that maxConn is respected.
+ // 4. If the number of nodes on the level exceeds maxConnOnLevel, assert that maxConnOnLevel
+ // is respected.
for (int level = 0; level < graphValues.numLevels(); level++) {
+ int maxConnOnLevel = level == 0 ? M * 2 : M;
int[][] graphOnLevel = new int[graphValues.size()][];
int countOnLevel = 0;
boolean foundOrphan = false;
@@ -508,7 +510,6 @@ public class TestKnnGraph extends LuceneTestCase {
}
countOnLevel++;
}
- // System.out.println("Level[" + level + "] has [" + nodesCount + "] nodes.");
assertEquals(nodesItr.size(), countOnLevel);
assertFalse("No nodes on level [" + level + "]", countOnLevel == 0);
if (countOnLevel == 1) {
@@ -517,13 +518,13 @@ public class TestKnnGraph extends LuceneTestCase {
} else {
assertFalse(
"Graph has orphan nodes with no friends on level [" + level + "]", foundOrphan);
- if (maxConn > countOnLevel) {
+ if (maxConnOnLevel > countOnLevel) {
// assert that the graph is fully connected,
// i.e. any node can be reached from any other node
assertConnected(graphOnLevel);
} else {
// assert that max-connections was respected
- assertMaxConn(graphOnLevel, maxConn);
+ assertMaxConn(graphOnLevel, maxConnOnLevel);
}
}
}
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
index a19ffdc583a..ed1b31d9861 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
@@ -62,14 +62,14 @@ public class TestHnswGraph extends LuceneTestCase {
RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, random());
RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy();
- int maxConn = random().nextInt(10) + 5;
+ int M = random().nextInt(10) + 5;
int beamWidth = random().nextInt(10) + 5;
long seed = random().nextLong();
VectorSimilarityFunction similarityFunction =
VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
HnswGraphBuilder builder =
- new HnswGraphBuilder(vectors, similarityFunction, maxConn, beamWidth, seed);
+ new HnswGraphBuilder(vectors, similarityFunction, M, beamWidth, seed);
HnswGraph hnsw = builder.build(vectors);
// Recreate the graph while indexing with the same random seed and write it out
@@ -84,7 +84,7 @@ public class TestHnswGraph extends LuceneTestCase {
new Lucene92Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
- return new Lucene92HnswVectorsFormat(maxConn, beamWidth);
+ return new Lucene92HnswVectorsFormat(M, beamWidth);
}
});
try (IndexWriter iw = new IndexWriter(dir, iwc)) {
@@ -153,12 +153,11 @@ public class TestHnswGraph extends LuceneTestCase {
// ensuring that we have all the distance functions, comparators, priority queues and so on
// oriented in the right directions
public void testAknnDiverse() throws IOException {
- int maxConn = 10;
int nDoc = 100;
CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder builder =
new HnswGraphBuilder(
- vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt());
+ vectors, VectorSimilarityFunction.DOT_PRODUCT, 10, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors);
// run some searches
NeighborQueue nn =
@@ -193,11 +192,10 @@ public class TestHnswGraph extends LuceneTestCase {
public void testSearchWithAcceptOrds() throws IOException {
int nDoc = 100;
- int maxConn = 16;
CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder builder =
new HnswGraphBuilder(
- vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt());
+ vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors);
// the first 10 docs must not be deleted to ensure the expected recall
Bits acceptOrds = createRandomAcceptOrds(10, vectors.size);
@@ -224,11 +222,10 @@ public class TestHnswGraph extends LuceneTestCase {
public void testSearchWithSelectiveAcceptOrds() throws IOException {
int nDoc = 100;
- int maxConn = 16;
CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder builder =
new HnswGraphBuilder(
- vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt());
+ vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors);
// Only mark a few vectors as accepted
BitSet acceptOrds = new FixedBitSet(vectors.size);
@@ -290,11 +287,10 @@ public class TestHnswGraph extends LuceneTestCase {
public void testVisitedLimit() throws IOException {
int nDoc = 500;
- int maxConn = 16;
CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder builder =
new HnswGraphBuilder(
- vectors, VectorSimilarityFunction.DOT_PRODUCT, maxConn, 100, random().nextInt());
+ vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors);
int topK = 50;
@@ -396,9 +392,7 @@ public class TestHnswGraph extends LuceneTestCase {
builder.addGraphNode(4, vectors.vectorValue(4));
// 4 is the same distance from 0 that 2 is; we leave the existing node in place
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
- // 4 is closer to 1 than either existing neighbor (0, 3). 3 fails diversity check with 4, so
- // replace it
- assertLevel0Neighbors(builder.hnsw, 1, 0, 4);
+ assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4);
assertLevel0Neighbors(builder.hnsw, 2, 0);
// 1 survives the diversity check
assertLevel0Neighbors(builder.hnsw, 3, 1, 4);
@@ -406,11 +400,11 @@ public class TestHnswGraph extends LuceneTestCase {
builder.addGraphNode(5, vectors.vectorValue(5));
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
- assertLevel0Neighbors(builder.hnsw, 1, 0, 5);
+ assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4, 5);
assertLevel0Neighbors(builder.hnsw, 2, 0);
// even though 5 is closer, 3 is not a neighbor of 5, so no update to *its* neighbors occurs
assertLevel0Neighbors(builder.hnsw, 3, 1, 4);
- assertLevel0Neighbors(builder.hnsw, 4, 3, 5);
+ assertLevel0Neighbors(builder.hnsw, 4, 1, 3, 5);
assertLevel0Neighbors(builder.hnsw, 5, 1, 4);
}
@@ -428,14 +422,13 @@ public class TestHnswGraph extends LuceneTestCase {
public void testRandom() throws IOException {
int size = atLeast(100);
int dim = atLeast(10);
- int maxConn = 10;
RandomVectorValues vectors = new RandomVectorValues(size, dim, random());
VectorSimilarityFunction similarityFunction =
VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
int topK = 5;
HnswGraphBuilder builder =
- new HnswGraphBuilder(vectors, similarityFunction, maxConn, 30, random().nextLong());
+ new HnswGraphBuilder(vectors, similarityFunction, 10, 30, random().nextLong());
OnHeapHnswGraph hnsw = builder.build(vectors);
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);