You are viewing a plain text version of this content. The canonical link for it was a hyperlink in the original page and is not preserved in this text.
Posted to commits@lucene.apache.org by so...@apache.org on 2020/12/22 14:48:49 UTC
[lucene-solr] branch master updated: LUCENE-9644: diversity
heuristic for HNSW graph neighbor selection (#2157)
This is an automated email from the ASF dual-hosted git repository.
sokolov pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/lucene-solr.git
The following commit(s) were added to refs/heads/master by this push:
new e1cd426 LUCENE-9644: diversity heuristic for HNSW graph neighbor selection (#2157)
e1cd426 is described below
commit e1cd426bce39abc4345b748d9cff5ff7fe10315f
Author: Michael Sokolov <so...@falutin.net>
AuthorDate: Tue Dec 22 09:48:24 2020 -0500
LUCENE-9644: diversity heuristic for HNSW graph neighbor selection (#2157)
* Add options to KnnGraphTester to support benchmarking with ann-benchmarks
* Switch to parallel-array-based storage in HnswGraph (previously used LongHeap)
---
lucene/core/.attach_pid13682 | 0
lucene/core/.attach_pid13912 | 0
.../codecs/lucene90/Lucene90VectorReader.java | 14 +-
.../codecs/lucene90/Lucene90VectorWriter.java | 31 +-
.../org/apache/lucene/index/KnnGraphValues.java | 10 +
.../src/java/org/apache/lucene/util/LongHeap.java | 80 ++---
.../org/apache/lucene/util/hnsw/BoundsChecker.java | 9 +-
.../org/apache/lucene/util/hnsw/HnswGraph.java | 129 +++----
.../apache/lucene/util/hnsw/HnswGraphBuilder.java | 239 +++++++------
.../org/apache/lucene/util/hnsw/NeighborArray.java | 73 ++++
.../hnsw/{Neighbors.java => NeighborQueue.java} | 101 +++---
.../test/org/apache/lucene/index/TestKnnGraph.java | 8 +-
.../test/org/apache/lucene/util/TestLongHeap.java | 111 ++----
.../org/apache/lucene/util/TestNumericUtils.java | 1 -
.../apache/lucene/util/hnsw/KnnGraphTester.java | 250 +++++++++++--
.../apache/lucene/util/hnsw/MockVectorValues.java | 138 ++++++++
.../test/org/apache/lucene/util/hnsw/TestHnsw.java | 394 ++++++++++-----------
.../org/apache/lucene/util/hnsw/TestNeighbors.java | 113 ++++++
18 files changed, 1066 insertions(+), 635 deletions(-)
diff --git a/lucene/core/.attach_pid13682 b/lucene/core/.attach_pid13682
new file mode 100644
index 0000000..e69de29
diff --git a/lucene/core/.attach_pid13912 b/lucene/core/.attach_pid13912
new file mode 100644
index 0000000..e69de29
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorReader.java
index 79d4dd0..6e6666d 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorReader.java
@@ -45,7 +45,7 @@ import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
-import org.apache.lucene.util.hnsw.Neighbors;
+import org.apache.lucene.util.hnsw.NeighborQueue;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
@@ -357,10 +357,7 @@ public final class Lucene90VectorReader extends VectorReader {
public TopDocs search(float[] vector, int topK, int fanout) throws IOException {
// use a seed that is fixed for the index so we get reproducible results for the same query
final Random random = new Random(checksumSeed);
- Neighbors results = HnswGraph.search(vector, topK + fanout, topK + fanout, randomAccess(), getGraphValues(fieldEntry), random);
- while (results.size() > topK) {
- results.pop();
- }
+ NeighborQueue results = HnswGraph.search(vector, topK, topK + fanout, randomAccess(), getGraphValues(fieldEntry), random);
int i = 0;
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), topK)];
boolean reversed = searchStrategy().reversed;
@@ -432,7 +429,7 @@ public final class Lucene90VectorReader extends VectorReader {
}
/** Read the nearest-neighbors graph from the index input */
- private final class IndexedKnnGraphReader extends KnnGraphValues {
+ private final static class IndexedKnnGraphReader extends KnnGraphValues {
final HnswGraphFieldEntry entry;
final IndexInput dataIn;
@@ -456,6 +453,11 @@ public final class Lucene90VectorReader extends VectorReader {
}
@Override
+ public int size() {
+ return entry.size();
+ }
+
+ @Override
public int nextNeighbor() throws IOException {
if (arcUpTo >= arcCount) {
return NO_MORE_DOCS;
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorWriter.java
index f1a2da9..64424ba 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorWriter.java
@@ -32,6 +32,7 @@ import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
+import org.apache.lucene.util.hnsw.NeighborArray;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
@@ -41,12 +42,14 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
*/
public final class Lucene90VectorWriter extends VectorWriter {
+ private final SegmentWriteState segmentWriteState;
private final IndexOutput meta, vectorData, vectorIndex;
private boolean finished;
Lucene90VectorWriter(SegmentWriteState state) throws IOException {
assert state.fieldInfos.hasVectorValues();
+ segmentWriteState = state;
String metaFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, Lucene90VectorFormat.META_EXTENSION);
meta = state.directory.createOutput(metaFileName, state.context);
@@ -138,18 +141,28 @@ public final class Lucene90VectorWriter extends VectorWriter {
}
private void writeGraph(IndexOutput graphData, RandomAccessVectorValuesProducer vectorValues, long graphDataOffset, long[] offsets, int count) throws IOException {
- HnswGraph graph = HnswGraphBuilder.build(vectorValues);
+ HnswGraphBuilder hnswGraphBuilder = new HnswGraphBuilder(vectorValues);
+ hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
+ HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
+
for (int ord = 0; ord < count; ord++) {
// write graph
offsets[ord] = graphData.getFilePointer() - graphDataOffset;
- int[] arcs = graph.getNeighborNodes(ord);
- Arrays.sort(arcs);
- graphData.writeInt(arcs.length);
- int lastArc = -1; // to make the assertion work?
- for (int arc : arcs) {
- assert arc > lastArc : "arcs out of order: " + lastArc + "," + arc;
- graphData.writeVInt(arc - lastArc);
- lastArc = arc;
+
+ NeighborArray neighbors = graph.getNeighbors(ord);
+ int size = neighbors.size();
+
+ // Destructively modify; this is safe because we discard the array after this
+ int[] nodes = neighbors.node();
+ Arrays.sort(nodes, 0, size);
+ graphData.writeInt(size);
+
+ int lastNode = -1; // to make the assertion work?
+ for (int i = 0; i < size; i++) {
+ int node = nodes[i];
+ assert node > lastNode : "nodes out of order: " + lastNode + "," + node;
+ graphData.writeVInt(node - lastNode);
+ lastNode = node;
}
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/index/KnnGraphValues.java b/lucene/core/src/java/org/apache/lucene/index/KnnGraphValues.java
index d3ee0dc..3e83ab9 100644
--- a/lucene/core/src/java/org/apache/lucene/index/KnnGraphValues.java
+++ b/lucene/core/src/java/org/apache/lucene/index/KnnGraphValues.java
@@ -37,6 +37,11 @@ public abstract class KnnGraphValues {
public abstract void seek(int target) throws IOException;
/**
+ * Returns the number of nodes in the graph
+ */
+ public abstract int size();
+
+ /**
* Iterates over the neighbor list. It is illegal to call this method after it returns
* NO_MORE_DOCS without calling {@link #seek(int)}, which resets the iterator.
* @return a node ordinal in the graph, or NO_MORE_DOCS if the iteration is complete.
@@ -54,5 +59,10 @@ public abstract class KnnGraphValues {
@Override
public void seek(int target) {
}
+
+ @Override
+ public int size() {
+ return 0;
+ }
};
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/LongHeap.java b/lucene/core/src/java/org/apache/lucene/util/LongHeap.java
index 29a8f83..d3db500 100644
--- a/lucene/core/src/java/org/apache/lucene/util/LongHeap.java
+++ b/lucene/core/src/java/org/apache/lucene/util/LongHeap.java
@@ -19,13 +19,11 @@ package org.apache.lucene.util;
/**
* A heap that stores longs; a primitive priority queue that like all priority queues maintains a
* partial ordering of its elements such that the least element can always be found in constant
- * time. Put()'s and pop()'s require log(size). This heap may be bounded by constructing with a
- * finite maxSize, or enabled to grow dynamically by passing the constant UNBOUNDED for the maxSize.
+ * time. Put()'s and pop()'s require log(size). This heap provides unbounded growth via {@link #push(long)},
+ * and bounded-size insertion based on its nominal maxSize via {@link #insertWithOverflow(long)}.
* The heap may be either a min heap, in which case the least element is the smallest integer, or a
* max heap, when it is the largest, depending on the Order parameter.
*
- * <b>NOTE</b>: Iteration order is not specified.
- *
* @lucene.internal
*/
public abstract class LongHeap {
@@ -38,33 +36,26 @@ public abstract class LongHeap {
*/
public enum Order {
MIN, MAX
- };
+ }
- private static final int UNBOUNDED = -1;
private final int maxSize;
private long[] heap;
private int size = 0;
/**
- * Create an empty priority queue of the configured size.
+ * Create an empty priority queue of the configured initial size.
* @param maxSize the maximum size of the heap, or if negative, the initial size of an unbounded heap
*/
LongHeap(int maxSize) {
final int heapSize;
- if (maxSize < 0) {
- // initial size; this may grow
- heapSize = -maxSize;
- this.maxSize = UNBOUNDED;
- } else {
- if ((maxSize < 1) || (maxSize >= ArrayUtil.MAX_ARRAY_LENGTH)) {
- // Throw exception to prevent confusing OOME:
- throw new IllegalArgumentException("maxSize must be UNBOUNDED(-1) or > 0 and < " + (ArrayUtil.MAX_ARRAY_LENGTH) + "; got: " + maxSize);
- }
- // NOTE: we add +1 because all access to heap is 1-based not 0-based. heap[0] is unused.
- heapSize = maxSize + 1;
- this.maxSize = maxSize;
+ if (maxSize < 1 || maxSize >= ArrayUtil.MAX_ARRAY_LENGTH) {
+ // Throw exception to prevent confusing OOME:
+ throw new IllegalArgumentException("maxSize must be > 0 and < " + (ArrayUtil.MAX_ARRAY_LENGTH - 1) + "; got: " + maxSize);
}
+ // NOTE: we add +1 because all access to heap is 1-based not 0-based. heap[0] is unused.
+ heapSize = maxSize + 1;
+ this.maxSize = maxSize;
this.heap = new long[heapSize];
}
@@ -94,14 +85,13 @@ public abstract class LongHeap {
public abstract boolean lessThan(long a, long b);
/**
- * Adds a value in log(size) time. If one tries to add more values than maxSize from initialize an
- * {@link ArrayIndexOutOfBoundsException} is thrown, unless maxSize is {@link #UNBOUNDED}.
+ * Adds a value in log(size) time. Grows unbounded as needed to accommodate new values.
*
* @return the new 'top' element in the queue.
*/
public final long push(long element) {
size++;
- if (maxSize == UNBOUNDED && size == heap.length) {
+ if (size == heap.length) {
heap = ArrayUtil.grow(heap, (size * 3 + 1) / 2);
}
heap[size] = element;
@@ -110,22 +100,23 @@ public abstract class LongHeap {
}
/**
- * Adds a value to an IntHeap in log(size) time. if the number of values would exceed the heap's
+ * Adds a value to a LongHeap in log(size) time. If the number of values would exceed the heap's
* maxSize, the least value is discarded.
* @return whether the value was added (unless the heap is full, or the new value is less than the top value)
*/
public boolean insertWithOverflow(long value) {
- if (size < maxSize || maxSize == UNBOUNDED) {
- push(value);
- return true;
- } else if (size > 0 && !lessThan(value, heap[1])) {
+ if (size >= maxSize) {
+ if (lessThan(value, heap[1])) {
+ return false;
+ }
updateTop(value);
return true;
}
- return false;
+ push(value);
+ return true;
}
- /** Returns the least element of the IntHeap in constant time. It is up to the caller to verify
+ /** Returns the least element of the LongHeap in constant time. It is up to the caller to verify
* that the heap is not empty; no checking is done, and if no elements have been added, 0 is
* returned.
*/
@@ -134,7 +125,7 @@ public abstract class LongHeap {
}
/** Removes and returns the least element of the PriorityQueue in log(size) time.
- * @throws IllegalStateException if the IntHeap is empty.
+ * @throws IllegalStateException if the LongHeap is empty.
*/
public final long pop() {
if (size > 0) {
@@ -184,7 +175,7 @@ public abstract class LongHeap {
size = 0;
}
- private final void upHeap(int origPos) {
+ private void upHeap(int origPos) {
int i = origPos;
long value = heap[i]; // save bottom value
int j = i >>> 1;
@@ -196,7 +187,7 @@ public abstract class LongHeap {
heap[i] = value; // install saved value
}
- private final void downHeap(int i) {
+ private void downHeap(int i) {
long value = heap[i]; // save top value
int j = i << 1; // find smaller child
int k = j + 1;
@@ -215,26 +206,17 @@ public abstract class LongHeap {
heap[i] = value; // install saved value
}
- public LongIterator iterator() {
- return new LongIterator();
+ public void pushAll(LongHeap other) {
+ for (int i = 1; i <= other.size; i++) {
+ push(other.heap[i]);
+ }
}
- /**
- * Iterator over the contents of the heap, returning successive ints.
+ /** Return the element at the ith location in the heap array. Use for iterating over elements when the order doesn't matter.
+ * Note that the valid arguments range from [1, size].
*/
- public class LongIterator {
- int i = 1;
-
- public boolean hasNext() {
- return i <= size;
- }
-
- public long next() {
- if (hasNext() == false) {
- throw new IllegalStateException("attempt to iterate beyond maximum element, size=" + size);
- }
- return heap[i++];
- }
+ public long get(int i) {
+ return heap[i];
}
/** This method returns the internal heap array.
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/BoundsChecker.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/BoundsChecker.java
index e02cc40..10572a88 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/BoundsChecker.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/BoundsChecker.java
@@ -27,7 +27,14 @@ abstract class BoundsChecker {
abstract void update(float sample);
/**
- * Return whether the sample exceeds (is worse than) the bound
+ * Update the bound unconditionally
+ */
+ void set(float sample) {
+ bound = sample;
+ }
+
+ /**
+ * @return whether the sample exceeds (is worse than) the bound
*/
abstract boolean check(float sample);
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
index 97c9175..d3546e2 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
@@ -19,14 +19,13 @@ package org.apache.lucene.util.hnsw;
import java.io.IOException;
import java.util.ArrayList;
-import java.util.HashSet;
import java.util.List;
import java.util.Random;
-import java.util.Set;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorValues;
+import org.apache.lucene.util.SparseFixedBitSet;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
@@ -54,18 +53,24 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
* <p>Note: The graph may be searched by multiple threads concurrently, but updates are not thread-safe. Also note: there is no notion of
* deletions. Document searching built on top of this must do its own deletion-filtering.</p>
*/
-public final class HnswGraph {
+public final class HnswGraph extends KnnGraphValues {
private final int maxConn;
private final VectorValues.SearchStrategy searchStrategy;
// Each entry lists the top maxConn neighbors of a node. The nodes correspond to vectors added to HnswBuilder, and the
// node values are the ordinals of those vectors.
- private final List<Neighbors> graph;
+ private final List<NeighborArray> graph;
+
+ // KnnGraphValues iterator members
+ private int upto;
+ private NeighborArray cur;
HnswGraph(int maxConn, VectorValues.SearchStrategy searchStrategy) {
graph = new ArrayList<>();
- graph.add(Neighbors.create(maxConn, searchStrategy));
+ // Typically with diversity criteria we see nodes not fully occupied; average fanout seems to be
+ // about 1/2 maxConn. Under-allocating costs some indexing time, but it saves RAM.
+ graph.add(new NeighborArray(Math.max(32, maxConn / 4)));
this.maxConn = maxConn;
this.searchStrategy = searchStrategy;
}
@@ -74,37 +79,40 @@ public final class HnswGraph {
* Searches for the nearest neighbors of a query vector.
* @param query search query vector
* @param topK the number of nodes to be returned
- * @param numSeed the number of random entry points to sample
+ * @param numSeed the size of the queue maintained while searching, and controls the number of random entry points to sample
* @param vectors vector values
* @param graphValues the graph values. May represent the entire graph, or a level in a hierarchical graph.
* @param random a source of randomness, used for generating entry points to the graph
- * @return a priority queue holding the neighbors found
+ * @return a priority queue holding the closest neighbors found
*/
- public static Neighbors search(float[] query, int topK, int numSeed, RandomAccessVectorValues vectors, KnnGraphValues graphValues,
- Random random) throws IOException {
+ public static NeighborQueue search(float[] query, int topK, int numSeed, RandomAccessVectorValues vectors, KnnGraphValues graphValues,
+ Random random) throws IOException {
VectorValues.SearchStrategy searchStrategy = vectors.searchStrategy();
+ int size = graphValues.size();
- Neighbors results = Neighbors.create(topK, searchStrategy);
- Neighbors candidates = Neighbors.createReversed(-numSeed, searchStrategy);
- // set of ordinals that have been visited by search on this layer, used to avoid backtracking
- Set<Integer> visited = new HashSet<>();
+ // MIN heap, holding the top results
+ NeighborQueue results = new NeighborQueue(numSeed, searchStrategy.reversed);
- int size = vectors.size();
+ // set of ordinals that have been visited by search on this layer, used to avoid backtracking
+ SparseFixedBitSet visited = new SparseFixedBitSet(size);
+ // get initial candidates at random
int boundedNumSeed = Math.min(numSeed, 2 * size);
for (int i = 0; i < boundedNumSeed; i++) {
int entryPoint = random.nextInt(size);
- if (visited.add(entryPoint)) {
- results.insertWithOverflow(entryPoint, searchStrategy.compare(query, vectors.vectorValue(entryPoint)));
+ if (visited.get(entryPoint) == false) {
+ visited.set(entryPoint);
+ // explore the topK starting points of some random numSeed probes
+ results.add(entryPoint, searchStrategy.compare(query, vectors.vectorValue(entryPoint)));
}
}
- Neighbors.NeighborIterator it = results.iterator();
- for (int nbr = it.next(); nbr != NO_MORE_DOCS; nbr = it.next()) {
- candidates.add(nbr, it.score());
- }
+
+ // MAX heap, from which to pull the candidate nodes
+ NeighborQueue candidates = results.copy(!searchStrategy.reversed);
+
// Set the bound to the worst current result and below reject any newly-generated candidates failing
// to exceed this bound
BoundsChecker bound = BoundsChecker.create(searchStrategy.reversed);
- bound.bound = results.topScore();
+ bound.set(results.topScore());
while (candidates.size() > 0) {
// get the best candidate (closest or best scoring)
float topCandidateScore = candidates.topScore();
@@ -117,87 +125,54 @@ public final class HnswGraph {
graphValues.seek(topCandidateNode);
int friendOrd;
while ((friendOrd = graphValues.nextNeighbor()) != NO_MORE_DOCS) {
- if (visited.contains(friendOrd)) {
+ if (visited.get(friendOrd)) {
continue;
}
- visited.add(friendOrd);
+ visited.set(friendOrd);
float score = searchStrategy.compare(query, vectors.vectorValue(friendOrd));
if (results.insertWithOverflow(friendOrd, score)) {
candidates.add(friendOrd, score);
- bound.bound = results.topScore();
+ bound.set(results.topScore());
}
}
}
- results.setVisitedCount(visited.size());
+ while (results.size() > topK) {
+ results.pop();
+ }
+ results.setVisitedCount(visited.approximateCardinality());
return results;
}
/**
- * Returns the {@link Neighbors} connected to the given node.
+ * Returns the {@link NeighborArray} holding the neighbors of the given node.
* @param node the node whose neighbors are returned
*/
- public Neighbors getNeighbors(int node) {
+ public NeighborArray getNeighbors(int node) {
return graph.get(node);
}
- public int[] getNeighborNodes(int node) {
- Neighbors neighbors = graph.get(node);
- int[] nodes = new int[neighbors.size()];
- Neighbors.NeighborIterator it = neighbors.iterator();
- for (int neighbor = it.next(), i = 0; neighbor != NO_MORE_DOCS; neighbor = it.next()) {
- nodes[i++] = neighbor;
- }
- return nodes;
- }
-
- /** Connects two nodes symmetrically, limiting the maximum number of connections from either node.
- * node1 must be less than node2 and must already have been inserted to the graph */
- void connectNodes(int node1, int node2, float score) {
- connect(node1, node2, score);
- if (node2 == graph.size()) {
- addNode();
- }
- connect(node2, node1, score);
- }
-
- KnnGraphValues getGraphValues() {
- return new HnswGraphValues();
- }
-
- /**
- * Makes a connection from the node to a neighbor, dropping the worst connection when maxConn is exceeded
- * @param node1 node to connect *from*
- * @param node2 node to connect *to*
- * @param score searchStrategy.score() of the vectors associated with the two nodes
- */
- boolean connect(int node1, int node2, float score) {
- //System.out.println(" HnswGraph.connect " + node1 + " -> " + node2);
- assert node1 >= 0 && node2 >= 0;
- return graph.get(node1)
- .insertWithOverflow(node2, score);
+ @Override
+ public int size() {
+ return graph.size();
}
int addNode() {
- graph.add(Neighbors.create(maxConn, searchStrategy));
+ graph.add(new NeighborArray(maxConn + 1));
return graph.size() - 1;
}
- /**
- * Present this graph as KnnGraphValues, used for searching while inserting new nodes.
- */
- private class HnswGraphValues extends KnnGraphValues {
-
- private Neighbors.NeighborIterator it;
-
- @Override
- public void seek(int targetNode) {
- it = HnswGraph.this.getNeighbors(targetNode).iterator();
- }
+ @Override
+ public void seek(int targetNode) {
+ cur = getNeighbors(targetNode);
+ upto = -1;
+ }
- @Override
- public int nextNeighbor() {
- return it.next();
+ @Override
+ public int nextNeighbor() {
+ if (++upto < cur.size()) {
+ return cur.node[upto];
}
+ return NO_MORE_DOCS;
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index e225c9b..250d0ba 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -21,15 +21,11 @@ import java.io.IOException;
import java.util.Locale;
import java.util.Random;
-import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorValues;
-import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.InfoStream;
-import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
-
/**
* Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the hyperparameters.
*/
@@ -42,48 +38,63 @@ public final class HnswGraphBuilder {
// expose for testing.
public static long randSeed = DEFAULT_RAND_SEED;
- // These "default" hyperparameter settings are exposed (and nonfinal) to enable performance testing
+ // These "default" hyper-parameter settings are exposed (and non-final) to enable performance testing
// since the indexing API doesn't provide any control over them.
// default max connections per node
public static int DEFAULT_MAX_CONN = 16;
// default candidate list size
- static int DEFAULT_BEAM_WIDTH = 16;
+ public static int DEFAULT_BEAM_WIDTH = 16;
private final int maxConn;
private final int beamWidth;
+ private final NeighborArray scratch;
- // TODO: how to pass this in?
- InfoStream infoStream = InfoStream.getDefault();
- // InfoStream infoStream = new PrintStreamInfoStream(System.out);
-
- private final BoundedVectorValues boundedVectors;
private final VectorValues.SearchStrategy searchStrategy;
- private final HnswGraph hnsw;
+ private final RandomAccessVectorValues vectorValues;
private final Random random;
+ private final BoundsChecker bound;
+ final HnswGraph hnsw;
- /**
- * Reads all the vectors from a VectorValues, builds a graph connecting them by their dense ordinals, using default
- * hyperparameter settings, and returns the resulting graph.
- * @param vectorValues the vectors whose relations are represented by the graph
- */
- public static HnswGraph build(RandomAccessVectorValuesProducer vectorValues) throws IOException {
- HnswGraphBuilder builder = new HnswGraphBuilder(vectorValues);
- return builder.build(vectorValues.randomAccess());
+ private InfoStream infoStream = InfoStream.getDefault();
+
+ // we need two sources of vectors in order to perform diversity check comparisons without colliding
+ private RandomAccessVectorValues buildVectors;
+
+ /** Construct the builder with default configurations */
+ public HnswGraphBuilder(RandomAccessVectorValuesProducer vectors) {
+ this(vectors, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, randSeed);
}
/**
* Reads all the vectors from a VectorValues, builds a graph connecting them by their dense ordinals, using the given
* hyperparameter settings, and returns the resulting graph.
- * @param vectorValues the vectors whose relations are represented by the graph
+ * @param vectors the vectors whose relations are represented by the graph - must provide a different view over those vectors
+ * than the one used to add via addGraphNode.
* @param maxConn the number of connections to make when adding a new graph node; roughly speaking the graph fanout.
* @param beamWidth the size of the beam search to use when finding nearest neighbors.
* @param seed the seed for a random number generator used during graph construction. Provide this to ensure repeatable construction.
*/
- public static HnswGraph build(RandomAccessVectorValuesProducer vectorValues, int maxConn, int beamWidth, long seed) throws IOException {
- HnswGraphBuilder builder = new HnswGraphBuilder(vectorValues, maxConn, beamWidth, seed);
- return builder.build(vectorValues.randomAccess());
+ public HnswGraphBuilder(RandomAccessVectorValuesProducer vectors, int maxConn, int beamWidth, long seed) {
+ vectorValues = vectors.randomAccess();
+ buildVectors = vectors.randomAccess();
+ searchStrategy = vectorValues.searchStrategy();
+ if (searchStrategy == VectorValues.SearchStrategy.NONE) {
+ throw new IllegalStateException("No distance function");
+ }
+ if (maxConn <= 0) {
+ throw new IllegalArgumentException("maxConn must be positive");
+ }
+ if (beamWidth <= 0) {
+ throw new IllegalArgumentException("beamWidth must be positive");
+ }
+ this.maxConn = maxConn;
+ this.beamWidth = beamWidth;
+ this.hnsw = new HnswGraph(maxConn, searchStrategy);
+ bound = BoundsChecker.create(searchStrategy.reversed);
+ random = new Random(seed);
+ scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1));
}
/**
@@ -91,18 +102,19 @@ public final class HnswGraphBuilder {
* without extra data copying, while avoiding collision of the returned values.
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet accessor for the vectors
*/
- HnswGraph build(RandomAccessVectorValues vectors) throws IOException {
- if (vectors == boundedVectors.raDelegate) {
+ public HnswGraph build(RandomAccessVectorValues vectors) throws IOException {
+ if (vectors == vectorValues) {
throw new IllegalArgumentException("Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
}
long start = System.nanoTime(), t = start;
+ // start at node 1! node 0 is added implicitly, in the constructor
for (int node = 1; node < vectors.size(); node++) {
- insert(vectors.vectorValue(node));
+ addGraphNode(vectors.vectorValue(node));
if (node % 10000 == 0) {
if (infoStream.isEnabled(HNSW_COMPONENT)) {
long now = System.nanoTime();
infoStream.message(HNSW_COMPONENT,
- String.format(Locale.ROOT, "HNSW built %d in %d/%d ms", node, ((now - t) / 1_000_000), ((now - start) / 1_000_000)));
+ String.format(Locale.ROOT, "built %d in %d/%d ms", node, ((now - t) / 1_000_000), ((now - start) / 1_000_000)));
t = now;
}
}
@@ -110,105 +122,120 @@ public final class HnswGraphBuilder {
return hnsw;
}
- /** Construct the builder with default configurations */
- private HnswGraphBuilder(RandomAccessVectorValuesProducer vectors) {
- this(vectors, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, randSeed);
- }
-
- /** Full constructor */
- HnswGraphBuilder(RandomAccessVectorValuesProducer vectors, int maxConn, int beamWidth, long seed) {
- RandomAccessVectorValues vectorValues = vectors.randomAccess();
- searchStrategy = vectorValues.searchStrategy();
- if (searchStrategy == VectorValues.SearchStrategy.NONE) {
- throw new IllegalStateException("No distance function");
- }
- if (maxConn <= 0) {
- throw new IllegalArgumentException("maxConn must be positive");
- }
- if (beamWidth <= 0) {
- throw new IllegalArgumentException("beamWidth must be positive");
- }
- this.maxConn = maxConn;
- this.beamWidth = beamWidth;
- boundedVectors = new BoundedVectorValues(vectorValues);
- this.hnsw = new HnswGraph(maxConn, searchStrategy);
- random = new Random(seed);
+ public void setInfoStream(InfoStream infoStream) {
+ this.infoStream = infoStream;
}
/** Inserts a doc with vector value to the graph */
- private void insert(float[] value) throws IOException {
- addGraphNode(value);
-
- // add the vector value
- boundedVectors.inc();
- }
-
- private void addGraphNode(float[] value) throws IOException {
- KnnGraphValues graphValues = hnsw.getGraphValues();
- Neighbors candidates = HnswGraph.search(value, beamWidth, 2 * beamWidth, boundedVectors, graphValues, random);
+ void addGraphNode(float[] value) throws IOException { // package-private: KnnGraphTester.dumpGraph calls it directly
+ NeighborQueue candidates = HnswGraph.search(value, beamWidth, beamWidth, vectorValues, hnsw, random);
int node = hnsw.addNode();
- // connect the nearest neighbors to the just inserted node
- addNearestNeighbors(node, candidates);
+ // connect neighbors to the new node, using a diversity heuristic that chooses successive
+ // nearest neighbors that are closer to the new node than they are to the previously-selected
+ // neighbors
+ addDiverseNeighbors(node, candidates, buildVectors);
}
- private void addNearestNeighbors(int newNode, Neighbors neighbors) {
- // connect the nearest neighbors, relying on the graph's Neighbors' priority queues to drop off distant neighbors
- Neighbors.NeighborIterator it = neighbors.iterator();
- for (int node = it.next(); node != NO_MORE_DOCS; node = it.next()) {
- float score = it.score();
- if (hnsw.connect(newNode, node, score)) {
- hnsw.connect(node, newNode, score);
+ private void addDiverseNeighbors(int node, NeighborQueue candidates, RandomAccessVectorValues vectors) throws IOException {
+ // For each of the beamWidth nearest candidates (going from best to worst), select it only if it is closer to target
+ // than it is to any of the already-selected neighbors (ie selected in this method, since the node is new and has no
+ // prior neighbors).
+ NeighborArray neighbors = hnsw.getNeighbors(node);
+ assert neighbors.size() == 0; // new node
+ popToScratch(candidates); // scratch now holds every candidate, ordered worst to best
+ selectDiverse(neighbors, scratch, vectors);
+
+ // Link the selected nodes to the new node, and the new node to the selected nodes (again applying diversity heuristic)
+ int size = neighbors.size();
+ for (int i = 0; i < size; i++) {
+ int nbr = neighbors.node[i];
+ NeighborArray nbrNbr = hnsw.getNeighbors(nbr);
+ nbrNbr.add(node, neighbors.score[i]);
+ if (nbrNbr.size() > maxConn) { // over capacity: diversityUpdate removes exactly one entry
+ diversityUpdate(nbrNbr, buildVectors);
}
}
}
- /**
- * Provides a random access VectorValues view over a delegate VectorValues, bounding the maximum ord.
- * TODO: get rid of this, all it does is track a counter
- */
- private static class BoundedVectorValues implements RandomAccessVectorValues {
-
- final RandomAccessVectorValues raDelegate;
-
- int size;
-
- BoundedVectorValues(RandomAccessVectorValues delegate) {
- raDelegate = delegate;
- if (delegate.size() > 0) {
- // we implicitly add the first node
- size = 1;
+ private void selectDiverse(NeighborArray neighbors, NeighborArray candidates, RandomAccessVectorValues vectors) throws IOException {
+ // Select the best maxConn neighbors of the new node, applying the diversity heuristic
+ for (int i = candidates.size() - 1; neighbors.size() < maxConn && i >= 0; i--) {
+ // compare each neighbor (in distance order) against the closer neighbors selected so far,
+ // only adding it if it is closer to the target than to any of the other selected neighbors
+ int cNode = candidates.node[i];
+ float cScore = candidates.score[i];
+ if (diversityCheck(vectorValues.vectorValue(cNode), cScore, neighbors, vectors)) { // candidate via the vectorValues field, neighbors via the vectors arg: vectorValue() may reuse one buffer per instance
+ neighbors.add(cNode, cScore);
}
}
+ }
- void inc() {
- ++size;
- }
-
- @Override
- public int size() {
- return size;
+ private void popToScratch(NeighborQueue candidates) {
+ scratch.clear();
+ int candidateCount = candidates.size();
+ // extract all the Neighbors from the queue into an array; these will now be
+ // sorted from worst to best
+ for (int i = 0; i < candidateCount; i++) {
+ float score = candidates.topScore(); // read before pop() removes the top element
+ scratch.add(candidates.pop(), score);
}
+ }
- @Override
- public int dimension() { return raDelegate.dimension(); }
-
- @Override
- public VectorValues.SearchStrategy searchStrategy() {
- return raDelegate.searchStrategy();
+ /** Checks the diversity heuristic for one candidate against the neighbors selected so far.
+ * @param candidate the vector of a new candidate neighbor of a node n
+ * @param score the score of the new candidate and node n, to be compared with scores of the candidate and n's neighbors
+ * @param neighbors the neighbors selected so far
+ * @param vectorValues source of values used for making comparisons between candidate and existing neighbors; should be a different instance from the one {@code candidate} was read from, in case vectorValue() reuses its buffer
+ * @return whether the candidate is diverse given the existing neighbors
+ */
+ private boolean diversityCheck(float[] candidate, float score, NeighborArray neighbors, RandomAccessVectorValues vectorValues) throws IOException {
+ bound.set(score);
+ for (int i = 0; i < neighbors.size(); i++) {
+ float diversityCheck = searchStrategy.compare(candidate, vectorValues.vectorValue(neighbors.node[i]));
+ if (bound.check(diversityCheck) == false) {
+ return false;
+ }
}
+ return true;
+ }
- @Override
- public float[] vectorValue(int target) throws IOException {
- return raDelegate.vectorValue(target);
+ private void diversityUpdate(NeighborArray neighbors, RandomAccessVectorValues vectorValues) throws IOException {
+ assert neighbors.size() == maxConn + 1;
+ int replacePoint = findNonDiverse(neighbors, vectorValues);
+ if (replacePoint == -1) {
+ // none found; check score against worst existing neighbor
+ bound.set(neighbors.score[0]);
+ if (bound.check(neighbors.score[maxConn])) {
+ // drop the new neighbor; it is not competitive and there were no diversity failures
+ neighbors.removeLast();
+ return;
+ } else {
+ replacePoint = 0;
+ }
}
+ neighbors.node[replacePoint] = neighbors.node[maxConn]; // move the newest entry (at [maxConn]) into the replaced slot
+ neighbors.score[replacePoint] = neighbors.score[maxConn];
+ neighbors.removeLast();
+ }
- @Override
- public BytesRef binaryValue(int targetOrd) throws IOException {
- throw new UnsupportedOperationException();
+ // scan neighbors looking for diversity violations
+ private int findNonDiverse(NeighborArray neighbors, RandomAccessVectorValues vectors) throws IOException {
+ for (int i = neighbors.size() - 1; i >= 0; i--) {
+ // check each neighbor against its better-scoring neighbors. If it fails diversity check with them, drop it
+ int nbrNode = neighbors.node[i];
+ bound.set(neighbors.score[i]);
+ float[] nbrVector = vectorValues.vectorValue(nbrNode); // via the field, not the vectors arg: vectorValue() may reuse one buffer per instance
+ for (int j = maxConn; j > i; j--) {
+ float diversityCheck = searchStrategy.compare(nbrVector, vectors.vectorValue(neighbors.node[j]));
+ if (bound.check(diversityCheck) == false) {
+ // node i is too similar to node j, given node i's score relative to the base node:
+ // report i so the caller replaces node i with the new node, which is at [maxConn]
+ return i;
+ }
+ }
}
+ return -1;
}
-
-
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
new file mode 100644
index 0000000..ac7a9a1
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util.hnsw;
+
+import org.apache.lucene.util.ArrayUtil;
+
+/** NeighborArray encodes the neighbors of a node and their mutual scores in the HNSW graph as a pair of growable arrays.
+ * @lucene.internal
+ */
+public class NeighborArray {
+
+ private int size;
+ // parallel arrays: node[i] is a neighbor id and score[i] its score; only the first size() entries are in use
+
+ float[] score;
+ int[] node;
+
+ NeighborArray(int maxSize) { // maxSize is only the initial capacity; add() grows the arrays
+ node = new int[maxSize];
+ score = new float[maxSize];
+ }
+
+ public void add(int newNode, float newScore) { // appends, growing both arrays only once they are full
+ if (size == node.length) {
+ node = ArrayUtil.grow(node, (size + 1) * 3 / 2);
+ score = ArrayUtil.growExact(score, node.length);
+ }
+ node[size] = newNode;
+ score[size] = newScore;
+ ++size;
+ }
+
+ public int size() {
+ return size;
+ }
+
+ /** Direct access to the internal list of node ids; provided for efficient writing of the graph
+ * @lucene.internal
+ */
+ public int[] node() {
+ return node;
+ }
+
+ public void clear() {
+ size = 0;
+ }
+
+ void removeLast() { // no empty check: callers must ensure size() > 0
+ size--;
+ }
+
+ @Override
+ public String toString() {
+ return "NeighborArray[" + size + "]";
+ }
+
+}
+
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/Neighbors.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java
similarity index 52%
rename from lucene/core/src/java/org/apache/lucene/util/hnsw/Neighbors.java
rename to lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java
index d193b38..9a2d67f 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/Neighbors.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java
@@ -17,52 +17,59 @@
package org.apache.lucene.util.hnsw;
-import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.LongHeap;
import org.apache.lucene.util.NumericUtils;
-import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
-
-/** Neighbors encodes the neighbors of a node in the HNSW graph. */
-public class Neighbors {
-
- private static final int INITIAL_SIZE = 128;
-
- public static Neighbors create(int maxSize, VectorValues.SearchStrategy searchStrategy) {
- return new Neighbors(maxSize, searchStrategy, searchStrategy.reversed);
- }
-
- public static Neighbors createReversed(int maxSize, VectorValues.SearchStrategy searchStrategy) {
- return new Neighbors(maxSize, searchStrategy, !searchStrategy.reversed);
- }
+/** NeighborQueue uses a {@link LongHeap} to store lists of arcs in an HNSW graph, represented as a neighbor node
+ * id with an associated score packed together as a sortable long, which is sorted primarily by score. The queue
+ * provides both fixed-size and unbounded operations via {@link #insertWithOverflow(int, float)} and {@link #add(int, float)},
+ * and orders its heap as either MIN or MAX according to the {@code reversed} constructor argument.
+ */
+public class NeighborQueue {
private final LongHeap heap;
- private final VectorValues.SearchStrategy searchStrategy;
// Used to track the number of neighbors visited during a single graph traversal
private int visitedCount;
- private Neighbors(int maxSize, VectorValues.SearchStrategy searchStrategy, boolean reversed) {
- this.searchStrategy = searchStrategy;
+ NeighborQueue(int initialSize, boolean reversed) {
if (reversed) {
- heap = LongHeap.create(LongHeap.Order.MAX, maxSize);
+ heap = LongHeap.create(LongHeap.Order.MAX, initialSize);
} else {
- heap = LongHeap.create(LongHeap.Order.MIN, maxSize);
+ heap = LongHeap.create(LongHeap.Order.MIN, initialSize);
}
}
- public int size() {
- return heap.size();
+ NeighborQueue copy(boolean reversed) { // new queue holding all current elements, ordered per reversed
+ int size = size();
+ NeighborQueue copy = new NeighborQueue(Math.max(1, size), reversed); // LongHeap rejects an initial size of 0
+ copy.heap.pushAll(heap);
+ return copy;
}
- public boolean reversed() {
- return searchStrategy.reversed;
+ /**
+ * @return the number of elements in the heap
+ */
+ public int size() {
+ return heap.size();
}
+ /**
+ * Adds a new graph arc, extending the storage as needed.
+ * @param newNode the neighbor node id
+ * @param newScore the score of the neighbor, relative to some other node
+ */
public void add(int newNode, float newScore) {
heap.push(encode(newNode, newScore));
}
+ /**
+ * If the heap is not full (size is less than the initialSize provided to the constructor), adds a new node-and-score element.
+ * If the heap is full, compares the score against the current top score, and replaces the top element if newScore is better than
+ * (greater than unless the heap is reversed), the current top score.
+ * @param newNode the neighbor node id
+ * @param newScore the score of the neighbor, relative to some other node
+ */
public boolean insertWithOverflow(int newNode, float newScore) {
return heap.insertWithOverflow(encode(newNode, newScore));
}
@@ -71,52 +78,46 @@ public class Neighbors {
return (((long) NumericUtils.floatToSortableInt(score)) << 32) | node;
}
+ /**
+ * Removes the top element and returns its node id.
+ */
public int pop() {
return (int) heap.pop();
}
+ int[] nodes() { // node ids in heap storage order, not sorted
+ int size = size();
+ int[] nodes = new int[size];
+ for (int i = 0; i < size; i++) {
+ nodes[i] = (int) heap.get(i + 1); // LongHeap storage appears to be 1-based
+ }
+ return nodes;
+ }
+
+ /**
+ * Returns the top element's node id.
+ */
public int topNode() {
return (int) heap.top();
}
+ /**
+ * Returns the top element's node score.
+ */
public float topScore() {
return NumericUtils.sortableIntToFloat((int) (heap.top() >> 32));
}
- void setVisitedCount(int visitedCount) {
- this.visitedCount = visitedCount;
- }
-
public int visitedCount() {
return visitedCount;
}
- public NeighborIterator iterator() {
- return new NeighborIterator();
- }
-
- class NeighborIterator {
- private long value;
- private final LongHeap.LongIterator heapIterator = heap.iterator();
-
- /** Return the next node */
- public int next() {
- if (heapIterator.hasNext()) {
- value = heapIterator.next();
- return (int) value;
- }
- return NO_MORE_DOCS;
- }
-
- /** Return the score corresponding to the last node returned by next() */
- public float score() {
- return NumericUtils.sortableIntToFloat((int) (value >> 32));
- }
+ void setVisitedCount(int visitedCount) {
+ this.visitedCount = visitedCount;
}
@Override
public String toString() {
return "Neighbors[" + heap.size() + "]";
}
-
}
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
index 35b4dfa..b8de53b 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
@@ -56,7 +56,7 @@ public class TestKnnGraph extends LuceneTestCase {
randSeed = random().nextLong();
if (random().nextBoolean()) {
maxConn = HnswGraphBuilder.DEFAULT_MAX_CONN;
- HnswGraphBuilder.DEFAULT_MAX_CONN = random().nextInt(1000) + 1;
+ HnswGraphBuilder.DEFAULT_MAX_CONN = random().nextInt(256) + 1;
}
}
@@ -148,8 +148,7 @@ public class TestKnnGraph extends LuceneTestCase {
*/
public void testSearch() throws Exception {
try (Directory dir = newDirectory();
- // don't allow random merges; they mess up the docid tie-breaking assertion
- IndexWriter iw = new IndexWriter(dir, new IndexWriterConfig().setCodec(Codec.forName("Lucene90")))) {
+ IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig())) {
// Add a document for every cartesian point in an NxN square so we can
// easily know which are the nearest neighbors to every point. Insert by iterating
// using a prime number that is not a divisor of N*N so that we will hit each point once,
@@ -294,7 +293,8 @@ public class TestKnnGraph extends LuceneTestCase {
}
if (HnswGraphBuilder.DEFAULT_MAX_CONN > graphSize) {
// assert that the graph in each leaf is connected and undirected (ie links are reciprocated)
- assertReciprocal(graph);
+ // We cannot assert this when diversity criterion is applied
+ // assertReciprocal(graph);
assertConnected(graph);
} else {
// assert that max-connections was respected
diff --git a/lucene/core/src/test/org/apache/lucene/util/TestLongHeap.java b/lucene/core/src/test/org/apache/lucene/util/TestLongHeap.java
index 9f34811..bdc55c0 100644
--- a/lucene/core/src/test/org/apache/lucene/util/TestLongHeap.java
+++ b/lucene/core/src/test/org/apache/lucene/util/TestLongHeap.java
@@ -18,7 +18,6 @@ package org.apache.lucene.util;
import java.util.ArrayList;
-import java.util.List;
import java.util.Random;
import static org.apache.lucene.util.LongHeap.Order.MAX;
@@ -27,7 +26,7 @@ import static org.apache.lucene.util.LongHeap.Order.MIN;
public class TestLongHeap extends LuceneTestCase {
private static class AssertingLongHeap extends LongHeap {
- public AssertingLongHeap(int count) {
+ AssertingLongHeap(int count) { // count is passed unchanged to the LongHeap constructor
super(count);
}
@@ -36,7 +35,7 @@ public class TestLongHeap extends LuceneTestCase {
return (a < b);
}
- protected final void checkValidity() {
+ final void checkValidity() {
long[] heapArray = getHeapArray();
for (int i = 1; i <= size(); i++) {
int parent = i >>> 1;
@@ -87,8 +86,10 @@ public class TestLongHeap extends LuceneTestCase {
public void testExceedBounds() {
LongHeap pq = LongHeap.create(MIN, 1);
pq.push(2);
- expectThrows(ArrayIndexOutOfBoundsException.class, () -> pq.push(0));
- assertEquals(2, pq.size()); // the heap is unusable at this point
+ pq.push(0);
+ // pushing past the initial size of 1 is expected to grow the heap rather than throw
+ assertEquals(2, pq.size()); // the heap has been extended to a new max size
+ assertEquals(0, pq.top());
}
public void testFixedSize() {
@@ -150,99 +151,45 @@ public class TestLongHeap extends LuceneTestCase {
lastLeast = newLeast;
}
}
-
- public void testIteratorEmpty() {
- LongHeap queue = LongHeap.create(MIN, 3);
- LongHeap.LongIterator it = queue.iterator();
- assertFalse(it.hasNext());
- expectThrows(IllegalStateException.class, () -> {
- it.next();
- });
- }
-
- public void testIteratorOne() {
- LongHeap queue = LongHeap.create(MIN, 3);
-
- queue.push(1);
- LongHeap.LongIterator it = queue.iterator();
- assertTrue(it.hasNext());
- assertEquals(1, it.next());
- assertFalse(it.hasNext());
- expectThrows(IllegalStateException.class, () -> {
- it.next();
- });
- }
-
- public void testIteratorTwo() {
- LongHeap queue = LongHeap.create(MIN, 3);
-
- queue.push(1);
- queue.push(2);
- LongHeap.LongIterator it = queue.iterator();
- assertTrue(it.hasNext());
- assertEquals(1, it.next());
- assertTrue(it.hasNext());
- assertEquals(2, it.next());
- assertFalse(it.hasNext());
- expectThrows(IllegalStateException.class, () -> {
- it.next();
- });
- }
-
- public void testIteratorRandom() {
- LongHeap.Order order;
- if (random().nextBoolean()) {
- order = MIN;
- } else {
- order = MAX;
- }
- final int maxSize = TestUtil.nextInt(random(), 1, 20);
- LongHeap queue = LongHeap.create(order, maxSize);
- final int iters = atLeast(100);
- final List<Long> expected = new ArrayList<>();
- for (int iter = 0; iter < iters; ++iter) {
- if (queue.size() == 0 || (queue.size() < maxSize && random().nextBoolean())) {
- final long value = random().nextInt(10);
- queue.push(value);
- expected.add(value);
- } else {
- expected.remove(Long.valueOf(queue.pop()));
- }
- List<Long> actual = new ArrayList<>();
- LongHeap.LongIterator it = queue.iterator();
- while (it.hasNext()) {
- actual.add(it.next());
- }
- CollectionUtil.introSort(expected);
- CollectionUtil.introSort(actual);
- assertEquals(expected, actual);
- }
+ public void testInvalid() { // initial sizes must be positive and below ArrayUtil.MAX_ARRAY_LENGTH
+ expectThrows(IllegalArgumentException.class, () -> LongHeap.create(MAX, -1));
+ expectThrows(IllegalArgumentException.class, () -> LongHeap.create(MAX, 0));
+ expectThrows(IllegalArgumentException.class, () -> LongHeap.create(MAX, ArrayUtil.MAX_ARRAY_LENGTH));
}
public void testUnbounded() {
- LongHeap pq = LongHeap.create(MAX, -1);
+ int initialSize = random().nextInt(10) + 1;
+ LongHeap pq = LongHeap.create(MAX, initialSize);
int num = random().nextInt(100) + 1;
- long maxValue = Long.MIN_VALUE;
+ long minValue = Long.MAX_VALUE;
+ int count = 0; // expected heap size: incremented for every element that increases size()
for (int i = 0; i < num; i++) {
long value = random().nextLong();
if (random().nextBoolean()) {
pq.push(value);
+ count++;
} else {
- pq.insertWithOverflow(value);
+ boolean full = pq.size() >= initialSize; // the test expects insertWithOverflow to add (not replace) only below initialSize
+ if (pq.insertWithOverflow(value)) {
+ if (full == false) {
+ count++;
+ }
+ }
}
- maxValue = Math.max(maxValue, value);
+ minValue = Math.min(minValue, value); // the test relies on the overall minimum surviving overflow evictions
}
- assertEquals(num, pq.size());
- assertEquals(maxValue, pq.top());
- long last = maxValue;
- int count = 0;
+ assertEquals(count, pq.size());
+ long last = Long.MAX_VALUE;
while (pq.size() > 0) {
+ long top = pq.top();
long next = pq.pop();
- ++ count;
+ assertEquals(top, next);
+ -- count;
assertTrue(next <= last);
last = next;
}
- assertEquals(num, count);
+ assertEquals(0, count);
+ assertEquals(minValue, last);
}
}
diff --git a/lucene/core/src/test/org/apache/lucene/util/TestNumericUtils.java b/lucene/core/src/test/org/apache/lucene/util/TestNumericUtils.java
index 08567a5..7dbbfd2 100644
--- a/lucene/core/src/test/org/apache/lucene/util/TestNumericUtils.java
+++ b/lucene/core/src/test/org/apache/lucene/util/TestNumericUtils.java
@@ -488,5 +488,4 @@ public class TestNumericUtils extends LuceneTestCase {
Integer.signum(left.compareTo(right)));
}
}
-
}
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
index 9716f4b..6be49f9 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
@@ -17,6 +17,7 @@
package org.apache.lucene.util.hnsw;
+import java.io.Closeable;
import java.io.IOException;
import java.io.OutputStream;
import java.lang.management.ManagementFactory;
@@ -45,11 +46,15 @@ import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.RandomAccessVectorValues;
+import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.IntroSorter;
import org.apache.lucene.util.PrintStreamInfoStream;
import org.apache.lucene.util.SuppressForbidden;
@@ -68,6 +73,7 @@ public class KnnGraphTester {
private int numDocs;
private int dim;
private int topK;
+ private int warmCount;
private int numIters;
private int fanout;
private Path indexPath;
@@ -82,8 +88,8 @@ public class KnnGraphTester {
numIters = 1000;
dim = 256;
topK = 100;
+ warmCount = 1000;
fanout = topK;
- indexPath = Paths.get("knn_test_index");
}
public static void main(String... args) throws Exception {
@@ -91,24 +97,24 @@ public class KnnGraphTester {
}
private void run(String... args) throws Exception {
- String operation = null, docVectorsPath = null, queryPath = null;
+ String operation = null;
+ Path docVectorsPath = null, queryPath = null, outputPath = null;
for (int iarg = 0; iarg < args.length; iarg++) {
String arg = args[iarg];
switch(arg) {
- case "-generate":
case "-search":
case "-check":
case "-stats":
+ case "-dump":
if (operation != null) {
throw new IllegalArgumentException("Specify only one operation, not both " + arg + " and " + operation);
}
- if (iarg == args.length - 1) {
- throw new IllegalArgumentException("Operation " + arg + " requires a following pathname");
- }
operation = arg;
- docVectorsPath = args[++iarg];
if (operation.equals("-search")) {
- queryPath = args[++iarg];
+ if (iarg == args.length - 1) {
+ throw new IllegalArgumentException("Operation " + arg + " requires a following pathname");
+ }
+ queryPath = Paths.get(args[++iarg]);
}
break;
case "-fanout":
@@ -150,6 +156,21 @@ public class KnnGraphTester {
case "-reindex":
reindex = true;
break;
+ case "-topK":
+ if (iarg == args.length - 1) {
+ throw new IllegalArgumentException("-topK requires a following number");
+ }
+ topK = Integer.parseInt(args[++iarg]);
+ break;
+ case "-out":
+ outputPath = Paths.get(args[++iarg]);
+ break;
+ case "-warm":
+ warmCount = Integer.parseInt(args[++iarg]);
+ break;
+ case "-docs":
+ docVectorsPath = Paths.get(args[++iarg]);
+ break;
case "-forceMerge":
operation = "-forceMerge";
break;
@@ -161,31 +182,48 @@ public class KnnGraphTester {
//usage();
}
}
- if (operation == null) {
+ if (operation == null && reindex == false) {
usage();
}
+ indexPath = Paths.get(formatIndexPath(docVectorsPath));
if (reindex) {
if (docVectorsPath == null) {
throw new IllegalArgumentException("-docs argument is required when indexing");
}
- reindexTimeMsec = createIndex(Paths.get(docVectorsPath), indexPath);
+ reindexTimeMsec = createIndex(docVectorsPath, indexPath);
}
- switch (operation) {
- case "-search":
- if (docVectorsPath == null) {
- throw new IllegalArgumentException("-docs argument is required when searching");
- }
- testSearch(indexPath, Paths.get(queryPath), getNN(Paths.get(docVectorsPath), Paths.get(queryPath)));
- break;
- case "-forceMerge":
- forceMerge();
- break;
- case "-stats":
- printFanoutHist(indexPath);
- break;
+ if (operation != null) {
+ switch (operation) {
+ case "-search":
+ if (docVectorsPath == null) {
+ throw new IllegalArgumentException("missing -docs arg");
+ }
+ if (outputPath != null) {
+ testSearch(indexPath, queryPath, outputPath, null);
+ } else {
+ testSearch(indexPath, queryPath, null, getNN(docVectorsPath, queryPath));
+ }
+ break;
+ case "-forceMerge":
+ forceMerge();
+ break;
+ case "-dump":
+ dumpGraph(docVectorsPath);
+ break;
+ case "-stats":
+ printFanoutHist(indexPath);
+ break;
+ }
}
}
+ private String formatIndexPath(Path docsPath) { // index directory name derived from the docs file name and graph parameters
+ if (docsPath == null) {
+ throw new IllegalArgumentException("-docs argument is required");
+ }
+ return docsPath.getFileName() + "-" + HnswGraphBuilder.DEFAULT_MAX_CONN + "-" + HnswGraphBuilder.DEFAULT_BEAM_WIDTH + ".index";
+ }
+
@SuppressForbidden(reason="Prints stuff")
private void printFanoutHist(Path indexPath) throws IOException {
try (Directory dir = FSDirectory.open(indexPath);
@@ -200,6 +238,37 @@ public class KnnGraphTester {
}
}
+ private void dumpGraph(Path docsPath) throws IOException { // builds the graph one node at a time, printing all adjacency lists after each insertion
+ try (BinaryFileVectors vectors = new BinaryFileVectors(docsPath)) {
+ RandomAccessVectorValues values = vectors.randomAccess();
+ HnswGraphBuilder builder = new HnswGraphBuilder(vectors, HnswGraphBuilder.DEFAULT_MAX_CONN, HnswGraphBuilder.DEFAULT_BEAM_WIDTH, 0);
+ // start at node 1: node 0 is not added here (presumably the builder seeds the graph with it; confirm)
+ for (int i = 1; i < numDocs; i++) {
+ builder.addGraphNode(values.vectorValue(i));
+ System.out.println("\nITERATION " + i);
+ dumpGraph(builder.hnsw);
+ }
+ }
+ }
+
+ private void dumpGraph(HnswGraph hnsw) { // one line per node: its id, then [neighbor, score] pairs in ascending score order
+ for (int i = 0; i < hnsw.size(); i++) {
+ NeighborArray neighbors = hnsw.getNeighbors(i);
+ System.out.printf(Locale.ROOT, "%5d", i);
+ NeighborArray sorted = new NeighborArray(neighbors.size()); // sort a copy so the graph's own neighbor order is untouched
+ for (int j = 0; j < neighbors.size(); j++) {
+ int node = neighbors.node[j];
+ float score = neighbors.score[j];
+ sorted.add(node, score);
+ }
+ new NeighborArraySorter(sorted).sort(0, sorted.size());
+ for (int j = 0; j < sorted.size(); j++) {
+ System.out.printf(Locale.ROOT, " [%d, %.4f]", sorted.node[j], sorted.score[j]);
+ }
+ System.out.println();
+ }
+ }
+
@SuppressForbidden(reason="Prints stuff")
private void forceMerge() throws IOException {
IndexWriterConfig iwc = new IndexWriterConfig()
@@ -253,7 +322,7 @@ public class KnnGraphTester {
}
@SuppressForbidden(reason="Prints stuff")
- private void testSearch(Path indexPath, Path queryPath, int[][] nn) throws IOException {
+ private void testSearch(Path indexPath, Path queryPath, Path outputPath, int[][] nn) throws IOException {
TopDocs[] results = new TopDocs[numIters];
long elapsed, totalCpuTime, totalVisited = 0;
try (FileChannel q = FileChannel.open(queryPath)) {
@@ -269,8 +338,8 @@ public class KnnGraphTester {
long cpuTimeStartNs;
try (Directory dir = FSDirectory.open(indexPath);
DirectoryReader reader = DirectoryReader.open(dir)) {
-
- for (int i = 0; i < 1000; i++) {
+ numDocs = reader.maxDoc();
+ for (int i = 0; i < warmCount; i++) {
// warm up
targets.get(target);
results[i] = doKnnSearch(reader, KNN_FIELD, target, topK, fanout);
@@ -296,14 +365,28 @@ public class KnnGraphTester {
+ "CPU time=" + totalCpuTime + "ms");
}
}
- if (quiet == false) {
- System.out.println("checking results");
- }
- float recall = checkResults(results, nn);
- totalVisited /= numIters;
- if (quiet) {
- System.out.printf(Locale.ROOT, "%5.3f\t%5.2f\t%d\t%d\t%d\t%d\t%d\t%d\n", recall, totalCpuTime / (float) numIters,
- numDocs, fanout, HnswGraphBuilder.DEFAULT_MAX_CONN, HnswGraphBuilder.DEFAULT_BEAM_WIDTH, totalVisited, reindexTimeMsec);
+ if (outputPath != null) {
+ ByteBuffer buf = ByteBuffer.allocate(4);
+ IntBuffer ibuf = buf.order(ByteOrder.LITTLE_ENDIAN).asIntBuffer();
+ try (OutputStream out = Files.newOutputStream(outputPath)) {
+ for (int i = 0; i < numIters; i++) {
+ for (ScoreDoc doc : results[i].scoreDocs) {
+ ibuf.position(0);
+ ibuf.put(doc.doc);
+ out.write(buf.array());
+ }
+ }
+ }
+ } else {
+ if (quiet == false) {
+ System.out.println("checking results");
+ }
+ float recall = checkResults(results, nn);
+ totalVisited /= numIters;
+ if (quiet) {
+ System.out.printf(Locale.ROOT, "%5.3f\t%5.2f\t%d\t%d\t%d\t%d\t%d\t%d\n", recall, totalCpuTime / (float) numIters,
+ numDocs, fanout, HnswGraphBuilder.DEFAULT_MAX_CONN, HnswGraphBuilder.DEFAULT_BEAM_WIDTH, totalVisited, reindexTimeMsec);
+ }
}
}
@@ -426,7 +509,7 @@ public class KnnGraphTester {
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer();
offset += blockSize;
- Neighbors queue = Neighbors.create(topK, SEARCH_STRATEGY);
+ NeighborQueue queue = new NeighborQueue(topK, SEARCH_STRATEGY.reversed);
for (; j < numDocs && vectors.hasRemaining(); j++) {
vectors.get(vector);
float d = SEARCH_STRATEGY.compare(query, vector);
@@ -492,9 +575,104 @@ public class KnnGraphTester {
}
private static void usage() {
- String error = "Usage: TestKnnGraph -generate|-search|-stats|-check {datafile} [-beamWidth N]";
+ String error = "Usage: KnnGraphTester [-reindex] [-search {queryfile}|-stats|-check|-dump|-forceMerge] [-docs {datafile}] [-niter N] [-fanout N] [-maxConn N] [-beamWidth N] [-topK N] [-warm N] [-out {outfile}]";
System.err.println(error);
System.exit(1);
}
+ class BinaryFileVectors implements RandomAccessVectorValuesProducer, Closeable { // memory-maps numDocs vectors of dim little-endian floats
+
+ private final int size;
+ private final FileChannel in;
+ private final FloatBuffer mmap;
+
+ BinaryFileVectors(Path filePath) throws IOException {
+ long totalBytes = (long) numDocs * dim * Float.BYTES;
+ if (totalBytes > Integer.MAX_VALUE) {
+ throw new IllegalArgumentException("input over 2GB not supported");
+ }
+ int vectorByteSize = dim * Float.BYTES;
+ size = (int) (totalBytes / vectorByteSize);
+ in = FileChannel.open(filePath); // opened only after validation so a rejected input cannot leak the channel
+ mmap = in.map(FileChannel.MapMode.READ_ONLY, 0, totalBytes)
+ .order(ByteOrder.LITTLE_ENDIAN)
+ .asFloatBuffer();
+ }
+
+ @Override
+ public void close() throws IOException {
+ in.close();
+ }
+
+ @Override
+ public RandomAccessVectorValues randomAccess() {
+ return new Values();
+ }
+
+ class Values implements RandomAccessVectorValues {
+
+ float[] vector = new float[dim];
+ FloatBuffer source = mmap.slice();
+
+ @Override
+ public int size() {
+ return size;
+ }
+
+ @Override
+ public int dimension() {
+ return dim;
+ }
+
+ @Override
+ public VectorValues.SearchStrategy searchStrategy() {
+ return SEARCH_STRATEGY;
+ }
+
+ @Override
+ public float[] vectorValue(int targetOrd) {
+ int pos = targetOrd * dim;
+ source.position(pos);
+ source.get(vector);
+ return vector; // shared buffer: overwritten by the next call on this Values instance
+ }
+
+ @Override
+ public BytesRef binaryValue(int targetOrd) {
+ throw new UnsupportedOperationException();
+ }
+ }
+ }
+
+ static class NeighborArraySorter extends IntroSorter { // sorts a NeighborArray's parallel arrays by ascending score
+ private final int[] node;
+ private final float[] score;
+
+ NeighborArraySorter(NeighborArray neighbors) {
+ node = neighbors.node;
+ score = neighbors.score;
+ }
+
+ float pivot; // pivot score, captured by value because swap() may move the element at the pivot index
+
+ @Override
+ protected void swap(int i, int j) {
+ int tmpNode = node[i];
+ float tmpScore = score[i];
+ node[i] = node[j];
+ score[i] = score[j];
+ node[j] = tmpNode;
+ score[j] = tmpScore;
+ }
+
+ @Override
+ protected void setPivot(int i) {
+ pivot = score[i];
+ }
+
+ @Override
+ protected int comparePivot(int j) {
+ return Float.compare(pivot, score[j]);
+ }
+ }
}
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java
new file mode 100644
index 0000000..7616cf8
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util.hnsw;
+
+import org.apache.lucene.index.RandomAccessVectorValues;
+import org.apache.lucene.index.RandomAccessVectorValuesProducer;
+import org.apache.lucene.index.VectorValues;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.LuceneTestCase;
+
+/**
+ * In-memory vector source for tests. {@code values} is indexed by docID and may contain nulls
+ * (documents without a vector). The non-null entries are also exposed, in docID order, by dense
+ * ordinal through {@link #vectorValue(int)}.
+ */
+class MockVectorValues extends VectorValues implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
+ private final float[] scratch;
+
+ protected final int dimension;
+ protected final float[][] denseValues;
+ protected final float[][] values;
+ protected final SearchStrategy searchStrategy;
+ private final int numVectors;
+
+ private int pos = -1;
+
+ /**
+ * @param searchStrategy the strategy reported by {@link #searchStrategy()}
+ * @param values vectors indexed by docID; null entries are docs without a vector
+ * @throws IllegalArgumentException if {@code values} contains no non-null vector
+ */
+ MockVectorValues(SearchStrategy searchStrategy, float[][] values) {
+ this.searchStrategy = searchStrategy;
+ // Take the dimension from the first doc that has a vector: doc 0 may be empty in sparse input.
+ this.dimension = firstDimension(values);
+ this.values = values;
+ denseValues = new float[values.length][];
+ int count = 0;
+ for (float[] vector : values) {
+ if (vector != null) {
+ denseValues[count++] = vector;
+ }
+ }
+ numVectors = count;
+ scratch = new float[dimension];
+ }
+
+ /** Returns the length of the first non-null vector in {@code values}. */
+ private static int firstDimension(float[][] values) {
+ for (float[] vector : values) {
+ if (vector != null) {
+ return vector.length;
+ }
+ }
+ throw new IllegalArgumentException("values must contain at least one non-null vector");
+ }
+
+ /** Returns a new instance over the same vectors, with its own iteration state. */
+ public MockVectorValues copy() {
+ return new MockVectorValues(searchStrategy, values);
+ }
+
+ @Override
+ public int size() {
+ return numVectors;
+ }
+
+ @Override
+ public SearchStrategy searchStrategy() {
+ return searchStrategy;
+ }
+
+ @Override
+ public int dimension() {
+ return dimension;
+ }
+
+ @Override
+ public float[] vectorValue() {
+ if (LuceneTestCase.random().nextBoolean()) {
+ return values[pos];
+ } else {
+ // Sometimes use the same scratch array repeatedly, mimicking what the codec will do.
+ // This should help us catch cases of aliasing where the same VectorValues source is used twice in a
+ // single computation.
+ System.arraycopy(values[pos], 0, scratch, 0, dimension);
+ return scratch;
+ }
+ }
+
+ @Override
+ public RandomAccessVectorValues randomAccess() {
+ return copy();
+ }
+
+ @Override
+ public float[] vectorValue(int targetOrd) {
+ return denseValues[targetOrd];
+ }
+
+ @Override
+ public BytesRef binaryValue(int targetOrd) {
+ return null;
+ }
+
+ @Override
+ public TopDocs search(float[] target, int k, int fanout) {
+ return null;
+ }
+
+ /** Positions on {@code target} if it is a doc with a vector; otherwise leaves pos unchanged. */
+ private boolean seek(int target) {
+ if (target >= 0 && target < values.length && values[target] != null) {
+ pos = target;
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ @Override
+ public int docID() {
+ return pos;
+ }
+
+ @Override
+ public int nextDoc() {
+ return advance(pos + 1);
+ }
+
+ @Override
+ public int advance(int target) {
+ if (pos == NO_MORE_DOCS) {
+ return NO_MORE_DOCS;
+ }
+ // Land on the first doc >= target that has a vector, never moving backwards, as
+ // DocIdSetIterator requires.
+ for (int doc = Math.max(target, pos + 1); doc < values.length; doc++) {
+ if (seek(doc)) {
+ return pos;
+ }
+ }
+ // Exhausted: docID() must report NO_MORE_DOCS from now on as well.
+ pos = NO_MORE_DOCS;
+ return NO_MORE_DOCS;
+ }
+
+ @Override
+ public long cost() {
+ return size();
+ }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java
index f41bc31..a97d96b 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java
@@ -18,6 +18,7 @@
package org.apache.lucene.util.hnsw;
import java.io.IOException;
+import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;
@@ -39,6 +40,7 @@ import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.VectorUtil;
@@ -56,7 +58,8 @@ public class TestHnsw extends LuceneTestCase {
RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy();
long seed = random().nextLong();
HnswGraphBuilder.randSeed = seed;
- HnswGraph hnsw = HnswGraphBuilder.build((RandomAccessVectorValuesProducer) vectors);
+ HnswGraphBuilder builder = new HnswGraphBuilder(vectors);
+ HnswGraph hnsw = builder.build(vectors.randomAccess());
// Recreate the graph while indexing with the same random seed and write it out
HnswGraphBuilder.randSeed = seed;
try (Directory dir = newDirectory()) {
@@ -89,7 +92,7 @@ public class TestHnsw extends LuceneTestCase {
assertEquals(indexedDoc, ctx.reader().numDocs());
assertVectorsEqual(v3, values);
KnnGraphValues graphValues = ((Lucene90VectorReader) ((CodecReader) ctx.reader()).getVectorReader()).getGraphValues("field");
- assertGraphEqual(hnsw.getGraphValues(), graphValues, nVec);
+ assertGraphEqual(hnsw, graphValues, nVec);
}
}
}
@@ -98,55 +101,165 @@ public class TestHnsw extends LuceneTestCase {
// Make sure we actually approximately find the closest k elements. Mostly this is about
// ensuring that we have all the distance functions, comparators, priority queues and so on
// oriented in the right directions
- public void testAknn() throws IOException {
+ public void testAknnDiverse() throws IOException {
+ // Builds a graph over 100 points on the upper unit semicircle and queries it with (1, 0).
int nDoc = 100;
RandomAccessVectorValuesProducer vectors = new CircularVectorValues(nDoc);
- HnswGraph hnsw = HnswGraphBuilder.build(vectors);
+ HnswGraphBuilder builder = new HnswGraphBuilder(vectors, 16, 100, random().nextInt());
+ HnswGraph hnsw = builder.build(vectors.randomAccess());
// run some searches
- Neighbors nn = HnswGraph.search(new float[]{1, 0}, 10, 5, vectors.randomAccess(), hnsw.getGraphValues(), random());
+ NeighborQueue nn = HnswGraph.search(new float[]{1, 0}, 10, 5, vectors.randomAccess(), hnsw, random());
int sum = 0;
- Neighbors.NeighborIterator it = nn.iterator();
- for (int node = it.next(); node != NO_MORE_DOCS; node = it.next()) {
+ for (int node : nn.nodes()) {
sum += node;
}
// We expect to get approximately 100% recall; the lowest docIds are closest to zero; sum(0,9) = 45
assertTrue("sum(result docs)=" + sum, sum < 75);
+ for (int i = 0; i < nDoc; i++) {
+ NeighborArray neighbors = hnsw.getNeighbors(i);
+ int[] nodes = neighbors.node;
+ for (int j = 0; j < neighbors.size(); j++) {
+ // all neighbors should be valid node ids.
+ assertTrue(nodes[j] < nDoc);
+ }
+ }
+ }
+
+ public void testBoundsCheckerMax() {
+ BoundsChecker checker = BoundsChecker.create(false);
+ float bound = random().nextFloat() - 0.5f;
+ // With no update yet, any float > -MAX_VALUE passes.
+ assertFalse(checker.check(bound));
+ // Make bound the limit; the slack delta starts at zero.
+ checker.update(bound);
+ // bound itself and anything above it stay in bounds...
+ assertFalse(checker.check(bound));
+ assertFalse(checker.check(bound + 1));
+ // ...while a slightly smaller value is out.
+ assertTrue(checker.check(bound - 1e-5f));
}
- public void testMaxConnections() {
- // verify that maxConnections is observed, and that the retained arcs point to the best-scoring neighbors
- HnswGraph graph = new HnswGraph(1, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW);
- graph.connectNodes(0, 1, 0);
- assertArrayEquals(new int[]{1}, graph.getNeighborNodes(0));
- assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1));
- graph.connectNodes(0, 2, 0.4f);
- assertArrayEquals(new int[]{2}, graph.getNeighborNodes(0));
- assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1));
- assertArrayEquals(new int[]{0}, graph.getNeighborNodes(2));
- graph.connectNodes(2, 3, 0);
- assertArrayEquals(new int[]{2}, graph.getNeighborNodes(0));
- assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1));
- assertArrayEquals(new int[]{0}, graph.getNeighborNodes(2));
- assertArrayEquals(new int[]{2}, graph.getNeighborNodes(3));
-
- graph = new HnswGraph(1, VectorValues.SearchStrategy.EUCLIDEAN_HNSW);
- graph.connectNodes(0, 1, 1);
- assertArrayEquals(new int[]{1}, graph.getNeighborNodes(0));
- assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1));
- graph.connectNodes(0, 2, 2);
- assertArrayEquals(new int[]{1}, graph.getNeighborNodes(0));
- assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1));
- assertArrayEquals(new int[]{0}, graph.getNeighborNodes(2));
- graph.connectNodes(2, 3, 1);
- assertArrayEquals(new int[]{1}, graph.getNeighborNodes(0));
- assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1));
- assertArrayEquals(new int[]{3}, graph.getNeighborNodes(2));
- assertArrayEquals(new int[]{2}, graph.getNeighborNodes(3));
+ public void testBoundsCheckerMin() {
+ BoundsChecker checker = BoundsChecker.create(true);
+ float bound = random().nextFloat() - 0.5f;
+ // With no update yet, any float < MAX_VALUE passes.
+ assertFalse(checker.check(bound));
+ // Make bound the limit; the slack delta starts at zero.
+ checker.update(bound);
+ // bound itself and anything below it stay in bounds...
+ assertFalse(checker.check(bound));
+ assertFalse(checker.check(bound - 1));
+ // ...while a slightly larger value is out.
+ assertTrue(checker.check(bound + 1e-5f));
+ }
+
+ public void testHnswGraphBuilderInvalid() {
+ // A null vector source is rejected with NullPointerException...
+ expectThrows(NullPointerException.class,
+ () -> new HnswGraphBuilder(null, 0, 0, 0));
+ // ...and a zero in either of the next two int arguments with IllegalArgumentException.
+ expectThrows(IllegalArgumentException.class,
+ () -> new HnswGraphBuilder(new RandomVectorValues(1, 1, random()), 0, 10, 0));
+ expectThrows(IllegalArgumentException.class,
+ () -> new HnswGraphBuilder(new RandomVectorValues(1, 1, random()), 10, 0, 0));
+ }
+
+ public void testDiversity() throws IOException {
+ // Some carefully checked test cases with simple 2d vectors on the unit circle:
+ // unitVector2d(x) is the unit vector at angle x * PI, and node i is the i-th entry below.
+ MockVectorValues vectors = new MockVectorValues(VectorValues.SearchStrategy.DOT_PRODUCT_HNSW, new float[][]{
+ unitVector2d(0.5),
+ unitVector2d(0.75),
+ unitVector2d(0.2),
+ unitVector2d(0.9),
+ unitVector2d(0.8),
+ unitVector2d(0.77),
+ });
+ // First add nodes until everybody gets a full neighbor list
+ HnswGraphBuilder builder = new HnswGraphBuilder(vectors, 2, 10, random().nextInt());
+ // node 0 is added by the builder constructor, so explicit additions start at node 1
+ builder.addGraphNode(vectors.vectorValue(1));
+ builder.addGraphNode(vectors.vectorValue(2));
+ // now every node has tried to attach every other node as a neighbor, but
+ // some were excluded based on diversity check.
+ assertNeighbors(builder.hnsw, 0, 1, 2);
+ assertNeighbors(builder.hnsw, 1, 0);
+ assertNeighbors(builder.hnsw, 2, 0);
+
+ builder.addGraphNode(vectors.vectorValue(3));
+ assertNeighbors(builder.hnsw, 0, 1, 2);
+ // we added 3 here
+ assertNeighbors(builder.hnsw, 1, 0, 3);
+ assertNeighbors(builder.hnsw, 2, 0);
+ assertNeighbors(builder.hnsw, 3, 1);
+
+ // supplant an existing neighbor
+ builder.addGraphNode(vectors.vectorValue(4));
+ // 4 is the same distance from 0 that 2 is; we leave the existing node in place
+ assertNeighbors(builder.hnsw, 0, 1, 2);
+ // 4 is closer to 1 than either existing neighbor (0, 3). 3 fails diversity check with 4, so replace it
+ assertNeighbors(builder.hnsw, 1, 0, 4);
+ assertNeighbors(builder.hnsw, 2, 0);
+ // 1 survives the diversity check
+ assertNeighbors(builder.hnsw, 3, 1, 4);
+ assertNeighbors(builder.hnsw, 4, 1, 3);
+
+ builder.addGraphNode(vectors.vectorValue(5));
+ assertNeighbors(builder.hnsw, 0, 1, 2);
+ assertNeighbors(builder.hnsw, 1, 0, 5);
+ assertNeighbors(builder.hnsw, 2, 0);
+ // even though 5 is closer, 3 is not a neighbor of 5, so no update to *its* neighbors occurs
+ assertNeighbors(builder.hnsw, 3, 1, 4);
+ assertNeighbors(builder.hnsw, 4, 3, 5);
+ assertNeighbors(builder.hnsw, 5, 1, 4);
+ }
+
+ /**
+ * Asserts that the neighbor list of {@code node} in {@code graph} holds exactly {@code expected},
+ * ignoring order. Sorts {@code expected} in place.
+ */
+ private void assertNeighbors(HnswGraph graph, int node, int ... expected) {
+ NeighborArray neighbors = graph.getNeighbors(node);
+ int[] actual = ArrayUtil.copyOfSubArray(neighbors.node, 0, neighbors.size());
+ // Compare sorted copies so the order in which neighbors were stored does not matter.
+ Arrays.sort(actual);
+ Arrays.sort(expected);
+ String message = "expected: " + Arrays.toString(expected) + " actual: " + Arrays.toString(actual);
+ assertArrayEquals(message, expected, actual);
+ }
+
+ /** Checks that HNSW search over random vectors recovers more than 90% of the exact top-K hits. */
+ public void testRandom() throws IOException {
+ int size = atLeast(100);
+ int dim = atLeast(10);
+ int topK = 5;
+ int numQueries = 100;
+ RandomVectorValues vectors = new RandomVectorValues(size, dim, random());
+ HnswGraphBuilder builder = new HnswGraphBuilder(vectors, 10, 30, random().nextLong());
+ HnswGraph hnsw = builder.build(vectors);
+ int totalMatches = 0;
+ for (int i = 0; i < numQueries; i++) {
+ float[] query = randomVector(random(), dim);
+ NeighborQueue actual = HnswGraph.search(query, topK, 100, vectors, hnsw, random());
+ // Exact top-K by brute force over every stored vector (by dense ordinal), to measure recall.
+ NeighborQueue expected = new NeighborQueue(topK, vectors.searchStrategy.reversed);
+ for (int j = 0; j < size; j++) {
+ float[] v = vectors.vectorValue(j);
+ if (v != null) {
+ expected.insertWithOverflow(j, vectors.searchStrategy.compare(query, v));
+ }
+ }
+ assertEquals(topK, actual.size());
+ totalMatches += computeOverlap(actual.nodes(), expected.nodes());
+ }
+ double overlap = totalMatches / (double) (numQueries * topK);
+ // Keep test output quiet unless verbose logging was requested.
+ if (VERBOSE) {
+ System.out.println("overlap=" + overlap + " totalMatches=" + totalMatches);
+ }
+ assertTrue("overlap=" + overlap, overlap > 0.9);
}
- /** Returns vectors evenly distributed around the unit circle.
+ /**
+ * Counts the elements the two arrays have in common (multiset intersection size) by merging them
+ * in sorted order. Both arrays are sorted in place.
+ */
+ private int computeOverlap(int[] a, int[] b) {
+ Arrays.sort(a);
+ Arrays.sort(b);
+ int matches = 0;
+ int i = 0;
+ int j = 0;
+ while (i < a.length && j < b.length) {
+ int cmp = Integer.compare(a[i], b[j]);
+ if (cmp == 0) {
+ matches++;
+ i++;
+ j++;
+ } else if (cmp < 0) {
+ i++;
+ } else {
+ j++;
+ }
+ }
+ return matches;
+ }
+
+ /** Returns vectors evenly distributed around the upper unit semicircle.
*/
- class CircularVectorValues extends VectorValues implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
+ static class CircularVectorValues extends VectorValues implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
private final int size;
private final float[] value;
@@ -213,9 +326,7 @@ public class TestHnsw extends LuceneTestCase {
@Override
public float[] vectorValue(int ord) {
- value[0] = (float) Math.cos(Math.PI * ord / (double) size);
- value[1] = (float) Math.sin(Math.PI * ord / (double) size);
- return value;
+ return unitVector2d(ord / (double) size, value);
}
@Override
@@ -227,7 +338,16 @@ public class TestHnsw extends LuceneTestCase {
public TopDocs search(float[] target, int k, int fanout) {
return null;
}
+ }
+
+ /** Returns a new unit vector at angle {@code piRadians * PI} from the positive x axis. */
+ private static float[] unitVector2d(double piRadians) {
+ float[] out = new float[2];
+ return unitVector2d(piRadians, out);
+ }
+
+ /** Writes the unit vector at angle {@code piRadians * PI} into {@code value} and returns it. */
+ private static float[] unitVector2d(double piRadians, float[] value) {
+ double theta = Math.PI * piRadians;
+ value[0] = (float) Math.cos(theta);
+ value[1] = (float) Math.sin(theta);
+ return value;
}
private void assertGraphEqual(KnnGraphValues g, KnnGraphValues h, int size) throws IOException {
@@ -259,194 +379,40 @@ public class TestHnsw extends LuceneTestCase {
}
}
- public void testNeighbors() {
- // make sure we have the sign correct
- Neighbors nn = Neighbors.create(2, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW);
- assertTrue(nn.insertWithOverflow(2, 0.5f));
- assertTrue(nn.insertWithOverflow(1, 0.2f));
- assertTrue(nn.insertWithOverflow(3, 1f));
- assertEquals(0.5f, nn.topScore(), 0);
- nn.pop();
- assertEquals(1f, nn.topScore(), 0);
- nn.pop();
-
- Neighbors fn = Neighbors.create(2, VectorValues.SearchStrategy.EUCLIDEAN_HNSW);
- assertTrue(fn.insertWithOverflow(2, 2));
- assertTrue(fn.insertWithOverflow(1, 1));
- assertFalse(fn.insertWithOverflow(3, 3));
- assertEquals(2f, fn.topScore(), 0);
- fn.pop();
- assertEquals(1f, fn.topScore(), 0);
- }
-
- private static float[] randomVector(Random random, int dim) {
- float[] vec = new float[dim];
- for (int i = 0; i < dim; i++) {
- vec[i] = random.nextFloat();
- }
- VectorUtil.l2normalize(vec);
- return vec;
- }
-
/**
* Produces random vectors and caches them for random-access.
*/
- class RandomVectorValues extends VectorValues implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
-
- private final int dimension;
- private final float[][] denseValues;
- private final float[][] values;
- private final float[] scratch;
- private final SearchStrategy searchStrategy;
-
- final int numVectors;
- final int maxDoc;
-
- private int pos = -1;
+ static class RandomVectorValues extends MockVectorValues {
RandomVectorValues(int size, int dimension, Random random) {
- this.dimension = dimension;
- values = new float[size][];
- denseValues = new float[size][];
- scratch = new float[dimension];
- int sz = 0;
- int md = -1;
- for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
- values[offset] = randomVector(random, dimension);
- denseValues[sz++] = values[offset];
- md = offset;
- }
- numVectors = sz;
- maxDoc = md;
- // get a random SearchStrategy other than NONE (0)
- searchStrategy = SearchStrategy.values()[random.nextInt(SearchStrategy.values().length - 1) + 1];
- }
-
- private RandomVectorValues(int dimension, SearchStrategy searchStrategy, float[][] denseValues, float[][] values, int size) {
- this.dimension = dimension;
- this.searchStrategy = searchStrategy;
- this.values = values;
- this.denseValues = denseValues;
- scratch = new float[dimension];
- numVectors = size;
- maxDoc = values.length - 1;
- }
-
- public RandomVectorValues copy() {
- return new RandomVectorValues(dimension, searchStrategy, denseValues, values, numVectors);
- }
-
- @Override
- public int size() {
- return numVectors;
- }
-
- @Override
- public SearchStrategy searchStrategy() {
- return searchStrategy;
- }
-
- @Override
- public int dimension() {
- return dimension;
- }
-
- @Override
- public float[] vectorValue() {
- if(random().nextBoolean()) {
- return values[pos];
- } else {
- // Sometimes use the same scratch array repeatedly, mimicing what the codec will do.
- // This should help us catch cases of aliasing where the same VectorValues source is used twice in a
- // single computation.
- System.arraycopy(values[pos], 0, scratch, 0, dimension);
- return scratch;
- }
- }
-
- @Override
- public RandomAccessVectorValues randomAccess() {
- return copy();
- }
-
- @Override
- public float[] vectorValue(int targetOrd) {
- return denseValues[targetOrd];
+ super(SearchStrategy.values()[random.nextInt(SearchStrategy.values().length - 1) + 1],
+ createRandomVectors(size, dimension, random));
}
- @Override
- public BytesRef binaryValue(int targetOrd) {
- return null;
+ /** Creates a new iterator over the same vectors and search strategy as {@code other}. */
+ RandomVectorValues(RandomVectorValues other) {
+ super(other.searchStrategy, other.values);
}
@Override
- public TopDocs search(float[] target, int k, int fanout) {
- return null;
- }
-
- private boolean seek(int target) {
- if (target >= 0 && target < values.length && values[target] != null) {
- pos = target;
- return true;
- } else {
- return false;
- }
- }
-
- @Override
- public int docID() {
- return pos;
- }
-
- @Override
- public int nextDoc() {
- return advance(pos + 1);
+ public RandomVectorValues copy() {
+ return new RandomVectorValues(this);
}
- public int advance(int target) {
- while (++pos < values.length) {
- if (seek(pos)) {
- return pos;
- }
+ /**
+ * Returns a sparse, docID-indexed array of {@code size} slots in which slot 0 (when size > 0) and
+ * then every 1st-3rd slot after it holds a random vector; the skipped slots stay null.
+ */
+ private static float[][] createRandomVectors(int size, int dimension, Random random) {
+ float[][] vectors = new float[size][];
+ // Step by 1-3 so between zero and two empty docs separate consecutive vectors.
+ for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
+ vectors[offset] = randomVector(random, dimension);
}
- return NO_MORE_DOCS;
+ return vectors;
}
-
- @Override
- public long cost() {
- return size();
- }
-
- }
-
- public void testBoundsCheckerMax() {
- BoundsChecker max = BoundsChecker.create(false);
- float f = random().nextFloat() - 0.5f;
- // any float > -MAX_VALUE is in bounds
- assertFalse(max.check(f));
- // f is now the bound (minus some delta)
- max.update(f);
- assertFalse(max.check(f)); // f is not out of bounds
- assertFalse(max.check(f + 1)); // anything greater than f is in bounds
- assertTrue(max.check(f - 1e-5f)); // delta is zero initially
- }
-
- public void testBoundsCheckerMin() {
- BoundsChecker min = BoundsChecker.create(true);
- float f = random().nextFloat() - 0.5f;
- // any float < MAX_VALUE is in bounds
- assertFalse(min.check(f));
- // f is now the bound (minus some delta)
- min.update(f);
- assertFalse(min.check(f)); // f is not out of bounds
- assertFalse(min.check(f - 1)); // anything less than f is in bounds
- assertTrue(min.check(f + 1e-5f)); // delta is zero initially
}
- public void testHnswGraphBuilderInvalid() {
- expectThrows(NullPointerException.class, () -> new HnswGraphBuilder(null, 0, 0, 0));
- expectThrows(IllegalArgumentException.class, () -> new HnswGraphBuilder(new RandomVectorValues(1, 1, random()), 0, 10, 0));
- expectThrows(IllegalArgumentException.class, () -> new HnswGraphBuilder(new RandomVectorValues(1, 1, random()), 10, 0, 0));
+ /**
+ * Returns a new vector of {@code dim} components drawn from {@code random.nextFloat()}, then
+ * passed through {@link VectorUtil#l2normalize} in place.
+ */
+ private static float[] randomVector(Random random, int dim) {
+ float[] result = new float[dim];
+ for (int k = 0; k < result.length; k++) {
+ result[k] = random.nextFloat();
+ }
+ VectorUtil.l2normalize(result);
+ return result;
}
-
}
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighbors.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighbors.java
new file mode 100644
index 0000000..eda0763
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighbors.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util.hnsw;
+
+import org.apache.lucene.util.LuceneTestCase;
+
+/** Unit tests for {@link NeighborQueue}. */
+public class TestNeighbors extends LuceneTestCase {
+
+ public void testNeighborsProduct() {
+ // make sure we have the sign correct
+ NeighborQueue nn = new NeighborQueue(2, false);
+ assertTrue(nn.insertWithOverflow(2, 0.5f));
+ assertTrue(nn.insertWithOverflow(1, 0.2f));
+ assertTrue(nn.insertWithOverflow(3, 1f));
+ assertEquals(0.5f, nn.topScore(), 0);
+ nn.pop();
+ assertEquals(1f, nn.topScore(), 0);
+ nn.pop();
+ // both retained entries have now been popped
+ assertEquals(0, nn.size());
+ }
+
+ public void testNeighborsMaxHeap() {
+ NeighborQueue nn = new NeighborQueue(2, true);
+ assertTrue(nn.insertWithOverflow(2, 2));
+ assertTrue(nn.insertWithOverflow(1, 1));
+ // the queue is full and 3 scores worse than everything retained, so it is rejected
+ assertFalse(nn.insertWithOverflow(3, 3));
+ assertEquals(2f, nn.topScore(), 0);
+ nn.pop();
+ assertEquals(1f, nn.topScore(), 0);
+ }
+
+ public void testTopMaxHeap() {
+ NeighborQueue nn = new NeighborQueue(2, true);
+ nn.add(1, 2);
+ nn.add(2, 1);
+ // lower scores are better; highest score on top
+ assertEquals(2, nn.topScore(), 0);
+ assertEquals(1, nn.topNode());
+ }
+
+ public void testTopMinHeap() {
+ NeighborQueue nn = new NeighborQueue(2, false);
+ nn.add(1, 0.5f);
+ nn.add(2, -0.5f);
+ // higher scores are better; lowest score on top
+ assertEquals(-0.5f, nn.topScore(), 0);
+ assertEquals(2, nn.topNode());
+ }
+
+ public void testVisitedCount() {
+ NeighborQueue nn = new NeighborQueue(2, false);
+ nn.setVisitedCount(100);
+ assertEquals(100, nn.visitedCount());
+ }
+
+ public void testMaxSizeQueue() {
+ NeighborQueue nn = new NeighborQueue(2, false);
+ nn.add(1, 1);
+ nn.add(2, 2);
+ assertEquals(2, nn.size());
+ assertEquals(1, nn.topNode());
+
+ // insertWithOverflow does not extend the queue
+ nn.insertWithOverflow(3, 3);
+ assertEquals(2, nn.size());
+ assertEquals(2, nn.topNode());
+
+ // add does extend the queue beyond maxSize
+ nn.add(4, 1);
+ assertEquals(3, nn.size());
+ }
+
+ public void testUnboundedQueue() {
+ NeighborQueue nn = new NeighborQueue(1, true);
+ float maxScore = -2;
+ int maxNode = -1;
+ for (int i = 0; i < 256; i++) {
+ // initial size is 32
+ float score = random().nextFloat();
+ if (score > maxScore) {
+ maxScore = score;
+ maxNode = i;
+ }
+ nn.add(i, score);
+ }
+ assertEquals(maxScore, nn.topScore(), 0);
+ assertEquals(maxNode, nn.topNode());
+ }
+
+ public void testInvalidArguments() {
+ expectThrows(IllegalArgumentException.class, () -> new NeighborQueue(0, false));
+ }
+
+ public void testToString() {
+ assertEquals("Neighbors[0]", new NeighborQueue(2, false).toString());
+ }
+
+}