You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by so...@apache.org on 2020/12/09 20:41:30 UTC
[lucene-solr] branch master updated: LUCENE-9626 represent HNSW
graph neighbors using primitive arrays (#2108)
This is an automated email from the ASF dual-hosted git repository.
sokolov pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/lucene-solr.git
The following commit(s) were added to refs/heads/master by this push:
new af3e122 LUCENE-9626 represent HNSW graph neighbors using primitive arrays (#2108)
af3e122 is described below
commit af3e12265fed3a29af30719a5b7c46424f737af4
Author: Michael Sokolov <so...@falutin.net>
AuthorDate: Wed Dec 9 15:41:10 2020 -0500
LUCENE-9626 represent HNSW graph neighbors using primitive arrays (#2108)
* also adds LongHeap, a primitive long priority queue
---
lucene/NOTICE.txt | 3 +
.../codecs/lucene90/Lucene90VectorReader.java | 12 +-
.../codecs/lucene90/Lucene90VectorWriter.java | 2 +-
.../org/apache/lucene/document/VectorField.java | 3 +-
.../src/java/org/apache/lucene/util/LongHeap.java | 247 ++++++++++++++++++++
.../java/org/apache/lucene/util/NumericUtils.java | 1 -
.../java/org/apache/lucene/util/VectorUtil.java | 36 ++-
.../org/apache/lucene/util/hnsw/HnswGraph.java | 102 ++++-----
.../apache/lucene/util/hnsw/HnswGraphBuilder.java | 28 ++-
.../java/org/apache/lucene/util/hnsw/Neighbor.java | 70 ------
.../org/apache/lucene/util/hnsw/Neighbors.java | 115 ++++++----
.../test/org/apache/lucene/index/TestKnnGraph.java | 6 +-
.../org/apache/lucene/index/TestVectorValues.java | 22 +-
.../test/org/apache/lucene/util/TestLongHeap.java | 248 +++++++++++++++++++++
.../org/apache/lucene/util/TestNumericUtils.java | 1 +
.../org/apache/lucene/util/TestVectorUtil.java | 32 ++-
.../apache/lucene/util/hnsw/KnnGraphTester.java | 14 +-
.../test/org/apache/lucene/util/hnsw/TestHnsw.java | 109 ++++-----
18 files changed, 781 insertions(+), 270 deletions(-)
diff --git a/lucene/NOTICE.txt b/lucene/NOTICE.txt
index 7dea1da..daccd40 100644
--- a/lucene/NOTICE.txt
+++ b/lucene/NOTICE.txt
@@ -207,3 +207,6 @@ This software includes a binary and/or source version of data from
which can be obtained from
https://bitbucket.org/eunjeon/mecab-ko-dic/downloads/mecab-ko-dic-2.0.3-20170922.tar.gz
+
+The floating point precision conversion in NumericUtils.Float16Converter is derived from work by
+Jeroen van der Zijp, granted for use under the Apache license.
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorReader.java
index 674959f..79d4dd0 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorReader.java
@@ -45,7 +45,6 @@ import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
-import org.apache.lucene.util.hnsw.Neighbor;
import org.apache.lucene.util.hnsw.Neighbors;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
@@ -366,14 +365,13 @@ public final class Lucene90VectorReader extends VectorReader {
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), topK)];
boolean reversed = searchStrategy().reversed;
while (results.size() > 0) {
- Neighbor n = results.pop();
- float score;
+ int node = results.topNode();
+ float score = results.topScore();
+ results.pop();
if (reversed) {
- score = (float) Math.exp(- n.score() / vector.length);
- } else {
- score = n.score();
+ score = (float) Math.exp(-score / vector.length);
}
- scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[n.node()], score);
+ scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], score);
}
// always return >= the case where we can assert == is only when there are fewer than topK vectors in the index
return new TopDocs(new TotalHits(results.visitedCount(), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), scoreDocs);
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorWriter.java
index 71d103b..f1a2da9 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorWriter.java
@@ -142,7 +142,7 @@ public final class Lucene90VectorWriter extends VectorWriter {
for (int ord = 0; ord < count; ord++) {
// write graph
offsets[ord] = graphData.getFilePointer() - graphDataOffset;
- int[] arcs = graph.getNeighbors(ord);
+ int[] arcs = graph.getNeighborNodes(ord);
Arrays.sort(arcs);
graphData.writeInt(arcs.length);
int lastArc = -1; // to make the assertion work?
diff --git a/lucene/core/src/java/org/apache/lucene/document/VectorField.java b/lucene/core/src/java/org/apache/lucene/document/VectorField.java
index eff6e12..5e580ef 100644
--- a/lucene/core/src/java/org/apache/lucene/document/VectorField.java
+++ b/lucene/core/src/java/org/apache/lucene/document/VectorField.java
@@ -52,7 +52,8 @@ public class VectorField extends Field {
}
/** Creates a numeric vector field. Fields are single-valued: each document has either one value
- * or no value. Vectors of a single field share the same dimension and search strategy.
+ * or no value. Vectors of a single field share the same dimension and search strategy. Note that some strategies
+ * (notably dot-product) require values to be unit-length, which can be enforced using VectorUtil.l2normalize(float[]).
*
* @param name field name
* @param vector value
diff --git a/lucene/core/src/java/org/apache/lucene/util/LongHeap.java b/lucene/core/src/java/org/apache/lucene/util/LongHeap.java
new file mode 100644
index 0000000..29a8f83
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/util/LongHeap.java
@@ -0,0 +1,247 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.util;
+
+/**
+ * A heap that stores longs; a primitive priority queue that like all priority queues maintains a
+ * partial ordering of its elements such that the least element can always be found in constant
+ * time. push() and pop() require log(size). This heap may be bounded by constructing with a
+ * finite maxSize, or enabled to grow dynamically by passing the constant UNBOUNDED for the maxSize.
+ * The heap may be either a min heap, in which case the least element is the smallest value, or a
+ * max heap, when it is the largest, depending on the Order parameter.
+ *
+ * <b>NOTE</b>: Iteration order is not specified.
+ *
+ * @lucene.internal
+ */
+public abstract class LongHeap {
+
+  /**
+   * Used to specify the ordering of the heap. A min-heap provides access to the minimum element in
+   * constant time, and when bounded, retains the maximum <code>maxSize</code> elements. A max-heap
+   * conversely provides access to the maximum element in constant time, and when bounded retains
+   * the minimum <code>maxSize</code> elements.
+   */
+  public enum Order {
+    MIN, MAX
+  }
+
+  /** Pass as the maxSize to {@link #create} to get a heap that grows as needed instead of being bounded. */
+  public static final int UNBOUNDED = -1;
+  private final int maxSize;
+
+  private long[] heap;
+  private int size = 0;
+
+  /**
+   * Create an empty priority queue of the configured size.
+   * @param maxSize the maximum size of the heap, or if negative, the initial size of an unbounded heap
+   * @throws IllegalArgumentException if maxSize is 0, is not less than {@link ArrayUtil#MAX_ARRAY_LENGTH},
+   *     or is not greater than -{@link ArrayUtil#MAX_ARRAY_LENGTH}
+   */
+  LongHeap(int maxSize) {
+    final int heapSize;
+    if (maxSize < 0) {
+      if (maxSize <= -ArrayUtil.MAX_ARRAY_LENGTH) {
+        // also rejects Integer.MIN_VALUE, whose negation overflows back to a negative array size
+        throw new IllegalArgumentException("maxSize must be > -" + ArrayUtil.MAX_ARRAY_LENGTH + " for an unbounded heap; got: " + maxSize);
+      }
+      // initial size; this may grow
+      heapSize = -maxSize;
+      this.maxSize = UNBOUNDED;
+    } else {
+      if ((maxSize < 1) || (maxSize >= ArrayUtil.MAX_ARRAY_LENGTH)) {
+        // Throw exception to prevent confusing OOME:
+        throw new IllegalArgumentException("maxSize must be UNBOUNDED(-1) or > 0 and < " + (ArrayUtil.MAX_ARRAY_LENGTH) + "; got: " + maxSize);
+      }
+      // NOTE: we add +1 because all access to heap is 1-based not 0-based. heap[0] is unused.
+      heapSize = maxSize + 1;
+      this.maxSize = maxSize;
+    }
+    this.heap = new long[heapSize];
+  }
+
+  /**
+   * Creates an empty heap.
+   * @param order {@link Order#MIN} to keep the smallest value on top, {@link Order#MAX} for the largest
+   * @param maxSize the maximum number of values retained, or if negative (e.g. {@link #UNBOUNDED}),
+   *     a heap that grows as needed, whose backing array initially has -maxSize slots
+   */
+  public static LongHeap create(Order order, int maxSize) {
+    // TODO: override push() for unbounded queue
+    if (order == Order.MIN) {
+      return new LongHeap(maxSize) {
+        @Override
+        public boolean lessThan(long a, long b) {
+          return a < b;
+        }
+      };
+    } else {
+      return new LongHeap(maxSize) {
+        @Override
+        public boolean lessThan(long a, long b) {
+          return a > b;
+        }
+      };
+    }
+  }
+
+  /** Determines the ordering of objects in this priority queue. Subclasses must define this one
+   * method.
+   * @return <code>true</code> iff parameter <code>a</code> is less than parameter <code>b</code>.
+   */
+  public abstract boolean lessThan(long a, long b);
+
+  /**
+   * Adds a value in log(size) time. If the heap is bounded and already holds maxSize values, an
+   * {@link ArrayIndexOutOfBoundsException} is thrown and the heap is left unchanged; an
+   * {@link #UNBOUNDED} heap grows as needed instead.
+   *
+   * @return the new 'top' element in the queue.
+   */
+  public final long push(long element) {
+    // don't modify size until the array store has succeeded, so a full bounded heap stays consistent
+    int index = size + 1;
+    if (maxSize == UNBOUNDED && index == heap.length) {
+      heap = ArrayUtil.grow(heap, (index * 3 + 1) / 2);
+    }
+    heap[index] = element;
+    size = index;
+    upHeap(index);
+    return heap[1];
+  }
+
+  /**
+   * Adds a value to this heap in log(size) time. If the heap is bounded and full, the value replaces
+   * the current top (the least value, which is discarded) unless the value is less than that top.
+   * @return whether the value was added: false only when the heap is full and the new value is less
+   *     than the top value
+   */
+  public boolean insertWithOverflow(long value) {
+    if (size < maxSize || maxSize == UNBOUNDED) {
+      push(value);
+      return true;
+    } else if (size > 0 && !lessThan(value, heap[1])) {
+      updateTop(value);
+      return true;
+    }
+    return false;
+  }
+
+  /** Returns the least element of the heap in constant time. It is up to the caller to verify
+   * that the heap is not empty; no checking is done, and the result for an empty heap is
+   * unspecified (0 if nothing was ever added, otherwise a stale value).
+   */
+  public final long top() {
+    return heap[1];
+  }
+
+  /** Removes and returns the least element of the heap in log(size) time.
+   * @throws IllegalStateException if the heap is empty.
+   */
+  public final long pop() {
+    if (size > 0) {
+      long result = heap[1]; // save first value
+      heap[1] = heap[size]; // move last to first
+      size--;
+      downHeap(1); // adjust heap
+      return result;
+    } else {
+      throw new IllegalStateException("The heap is empty");
+    }
+  }
+
+  /**
+   * Replace the top of the pq with {@code newTop}. Should be called when the top value
+   * changes. Still log(n) worst case, but it's at least twice as fast to
+   *
+   * <pre class="prettyprint">
+   * pq.updateTop(value);
+   * </pre>
+   *
+   * instead of
+   *
+   * <pre class="prettyprint">
+   * pq.pop();
+   * pq.push(value);
+   * </pre>
+   *
+   * Calling this method on an empty LongHeap has no visible effect.
+   *
+   * @param value the new element to place at the top; it is sifted down to its correct position.
+   * @return the new 'top' element after shuffling the heap.
+   */
+  public final long updateTop(long value) {
+    heap[1] = value;
+    downHeap(1);
+    return heap[1];
+  }
+
+  /** Returns the number of elements currently stored in the heap. */
+  public final int size() {
+    return size;
+  }
+
+  /** Removes all entries from the heap. */
+  public final void clear() {
+    size = 0;
+  }
+
+  private final void upHeap(int origPos) {
+    int i = origPos;
+    long value = heap[i]; // save bottom value
+    int j = i >>> 1;
+    while (j > 0 && lessThan(value, heap[j])) {
+      heap[i] = heap[j]; // shift parents down
+      i = j;
+      j = j >>> 1;
+    }
+    heap[i] = value; // install saved value
+  }
+
+  private final void downHeap(int i) {
+    long value = heap[i]; // save top value
+    int j = i << 1; // find smaller child
+    int k = j + 1;
+    if (k <= size && lessThan(heap[k], heap[j])) {
+      j = k;
+    }
+    while (j <= size && lessThan(heap[j], value)) {
+      heap[i] = heap[j]; // shift up child
+      i = j;
+      j = i << 1;
+      k = j + 1;
+      if (k <= size && lessThan(heap[k], heap[j])) {
+        j = k;
+      }
+    }
+    heap[i] = value; // install saved value
+  }
+
+  /** Returns an iterator over the values currently in the heap, in unspecified (internal array) order. */
+  public LongIterator iterator() {
+    return new LongIterator();
+  }
+
+  /**
+   * Iterator over the contents of the heap, returning successive longs in unspecified order. It
+   * reads the live heap, so modifying the heap during iteration changes what is returned.
+   */
+  public class LongIterator {
+    int i = 1;
+
+    public boolean hasNext() {
+      return i <= size;
+    }
+
+    public long next() {
+      if (hasNext() == false) {
+        throw new IllegalStateException("attempt to iterate beyond maximum element, size=" + size);
+      }
+      return heap[i++];
+    }
+  }
+
+  /** This method returns the internal heap array.
+   * @lucene.internal
+   */
+  protected final long[] getHeapArray() {
+    return heap;
+  }
+
+}
diff --git a/lucene/core/src/java/org/apache/lucene/util/NumericUtils.java b/lucene/core/src/java/org/apache/lucene/util/NumericUtils.java
index f6c15c0..6ed8c7d 100644
--- a/lucene/core/src/java/org/apache/lucene/util/NumericUtils.java
+++ b/lucene/core/src/java/org/apache/lucene/util/NumericUtils.java
@@ -89,7 +89,6 @@ public final class NumericUtils {
return bits ^ (bits >> 31) & 0x7fffffff;
}
-
/** Result = a - b, where a >= b, else {@code IllegalArgumentException} is thrown. */
public static void subtract(int bytesPerDim, int dim, byte[] a, byte[] b, byte[] result) {
int start = dim * bytesPerDim;
diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
index e2f8fea..5c6346d 100644
--- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
+++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
@@ -25,7 +25,14 @@ public final class VectorUtil {
private VectorUtil() {
}
+ /**
+ * Returns the vector dot product of the two vectors. IllegalArgumentException is thrown if the vectors'
+ * dimensions differ.
+ */
public static float dotProduct(float[] a, float[] b) {
+ if (a.length != b.length) {
+ throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
+ }
float res = 0f;
/*
* If length of vector is larger than 8, we use unrolled dot product to accelerate the
@@ -60,8 +67,15 @@ public final class VectorUtil {
return res;
}
+ /**
+ * Returns the sum of squared differences of the two vectors. IllegalArgumentException is thrown if the vectors'
+ * dimensions differ.
+ */
+
public static float squareDistance(float[] v1, float[] v2) {
- assert v1.length == v2.length;
+ if (v1.length != v2.length) {
+ throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + v2.length);
+ }
float squareSum = 0.0f;
int dim = v1.length;
for (int i = 0; i < dim; i++) {
@@ -71,4 +85,24 @@ public final class VectorUtil {
return squareSum;
}
+  /**
+   * Modifies the argument to be unit length, dividing by its l2-norm.
+   * IllegalArgumentException is thrown for zero vectors.
+   *
+   * @param v the vector to normalize in place
+   * @throws IllegalArgumentException if every component of {@code v} is zero (or {@code v} is empty)
+   */
+  public static void l2normalize(float[] v) {
+    // Widen each component to double before squaring: squaring in float overflows to Infinity for
+    // components above roughly sqrt(Float.MAX_VALUE) (which would turn the vector into zeros) and
+    // underflows to 0 for components below roughly sqrt(Float.MIN_VALUE) (spuriously rejecting
+    // small but non-zero vectors).
+    double squareSum = 0.0;
+    for (float x : v) {
+      squareSum += (double) x * x;
+    }
+    if (squareSum == 0) {
+      throw new IllegalArgumentException("Cannot normalize a zero-length vector");
+    }
+    double length = Math.sqrt(squareSum);
+    for (int i = 0; i < v.length; i++) {
+      v[i] /= length;
+    }
+  }
+
+
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
index 37933a8..97c9175 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
@@ -19,12 +19,10 @@ package org.apache.lucene.util.hnsw;
import java.io.IOException;
import java.util.ArrayList;
-import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
-import java.util.TreeSet;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.RandomAccessVectorValues;
@@ -67,7 +65,7 @@ public final class HnswGraph {
HnswGraph(int maxConn, VectorValues.SearchStrategy searchStrategy) {
graph = new ArrayList<>();
- graph.add(Neighbors.create(maxConn, searchStrategy.reversed));
+ graph.add(Neighbors.create(maxConn, searchStrategy));
this.maxConn = maxConn;
this.searchStrategy = searchStrategy;
}
@@ -85,39 +83,38 @@ public final class HnswGraph {
public static Neighbors search(float[] query, int topK, int numSeed, RandomAccessVectorValues vectors, KnnGraphValues graphValues,
Random random) throws IOException {
VectorValues.SearchStrategy searchStrategy = vectors.searchStrategy();
- // TODO: use unbounded priority queue
- TreeSet<Neighbor> candidates;
- if (searchStrategy.reversed) {
- candidates = new TreeSet<>(Comparator.reverseOrder());
- } else {
- candidates = new TreeSet<>();
- }
+
+ Neighbors results = Neighbors.create(topK, searchStrategy);
+ Neighbors candidates = Neighbors.createReversed(-numSeed, searchStrategy);
+ // set of ordinals that have been visited by search on this layer, used to avoid backtracking
+ Set<Integer> visited = new HashSet<>();
+
int size = vectors.size();
- for (int i = 0; i < numSeed && i < size; i++) {
+ int boundedNumSeed = Math.min(numSeed, 2 * size);
+ for (int i = 0; i < boundedNumSeed; i++) {
int entryPoint = random.nextInt(size);
- candidates.add(new Neighbor(entryPoint, searchStrategy.compare(query, vectors.vectorValue(entryPoint))));
+ if (visited.add(entryPoint)) {
+ results.insertWithOverflow(entryPoint, searchStrategy.compare(query, vectors.vectorValue(entryPoint)));
+ }
}
- // set of ordinals that have been visited by search on this layer, used to avoid backtracking
- Set<Integer> visited = new HashSet<>();
- // TODO: use PriorityQueue's sentinel optimization?
- Neighbors results = Neighbors.create(topK, searchStrategy.reversed);
- for (Neighbor c : candidates) {
- visited.add(c.node());
- results.insertWithOverflow(c);
+ Neighbors.NeighborIterator it = results.iterator();
+ for (int nbr = it.next(); nbr != NO_MORE_DOCS; nbr = it.next()) {
+ candidates.add(nbr, it.score());
}
// Set the bound to the worst current result and below reject any newly-generated candidates failing
// to exceed this bound
BoundsChecker bound = BoundsChecker.create(searchStrategy.reversed);
- bound.bound = results.top().score();
+ bound.bound = results.topScore();
while (candidates.size() > 0) {
// get the best candidate (closest or best scoring)
- Neighbor c = candidates.pollLast();
+ float topCandidateScore = candidates.topScore();
if (results.size() >= topK) {
- if (bound.check(c.score())) {
+ if (bound.check(topCandidateScore)) {
break;
}
}
- graphValues.seek(c.node());
+ int topCandidateNode = candidates.pop();
+ graphValues.seek(topCandidateNode);
int friendOrd;
while ((friendOrd = graphValues.nextNeighbor()) != NO_MORE_DOCS) {
if (visited.contains(friendOrd)) {
@@ -125,11 +122,9 @@ public final class HnswGraph {
}
visited.add(friendOrd);
float score = searchStrategy.compare(query, vectors.vectorValue(friendOrd));
- if (results.size() < topK || bound.check(score) == false) {
- Neighbor n = new Neighbor(friendOrd, score);
- candidates.add(n);
- results.insertWithOverflow(n);
- bound.bound = results.top().score();
+ if (results.insertWithOverflow(friendOrd, score)) {
+ candidates.add(friendOrd, score);
+ bound.bound = results.topScore();
}
}
}
@@ -138,18 +133,21 @@ public final class HnswGraph {
}
/**
- * Returns the nodes connected to the given node by its outgoing neighborNodes in an unpredictable order. Each node inserted
- * by HnswGraphBuilder corresponds to a vector, and the node is the vector's ordinal.
- * @param node the node whose friends are returned
+ * Returns the {@link Neighbors} connected to the given node.
+ * @param node the node whose neighbors are returned
*/
- public int[] getNeighbors(int node) {
+ public Neighbors getNeighbors(int node) {
+ return graph.get(node);
+ }
+
+ public int[] getNeighborNodes(int node) {
Neighbors neighbors = graph.get(node);
- int[] result = new int[neighbors.size()];
- int i = 0;
- for (Neighbor n : neighbors) {
- result[i++] = n.node();
+ int[] nodes = new int[neighbors.size()];
+ Neighbors.NeighborIterator it = neighbors.iterator();
+ for (int neighbor = it.next(), i = 0; neighbor != NO_MORE_DOCS; neighbor = it.next()) {
+ nodes[i++] = neighbor;
}
- return result;
+ return nodes;
}
/** Connects two nodes symmetrically, limiting the maximum number of connections from either node.
@@ -175,24 +173,12 @@ public final class HnswGraph {
boolean connect(int node1, int node2, float score) {
//System.out.println(" HnswGraph.connect " + node1 + " -> " + node2);
assert node1 >= 0 && node2 >= 0;
- Neighbors nn = graph.get(node1);
- assert nn != null;
- if (nn.size() == maxConn) {
- Neighbor top = nn.top();
- if (score < top.score() == nn.reversed()) {
- top.update(node2, score);
- nn.updateTop();
- return true;
- }
- } else {
- nn.add(new Neighbor(node2, score));
- return true;
- }
- return false;
+ return graph.get(node1)
+ .insertWithOverflow(node2, score);
}
int addNode() {
- graph.add(Neighbors.create(maxConn, searchStrategy.reversed));
+ graph.add(Neighbors.create(maxConn, searchStrategy));
return graph.size() - 1;
}
@@ -201,23 +187,17 @@ public final class HnswGraph {
*/
private class HnswGraphValues extends KnnGraphValues {
- private int arcUpTo;
- private int[] neighborNodes;
+ private Neighbors.NeighborIterator it;
@Override
public void seek(int targetNode) {
- arcUpTo = 0;
- neighborNodes = HnswGraph.this.getNeighbors(targetNode);
+ it = HnswGraph.this.getNeighbors(targetNode).iterator();
}
@Override
public int nextNeighbor() {
- if (arcUpTo >= neighborNodes.length) {
- return NO_MORE_DOCS;
- }
- return neighborNodes[arcUpTo++];
+ return it.next();
}
-
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index d116179..e225c9b 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -18,6 +18,7 @@
package org.apache.lucene.util.hnsw;
import java.io.IOException;
+import java.util.Locale;
import java.util.Random;
import org.apache.lucene.index.KnnGraphValues;
@@ -25,6 +26,9 @@ import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.InfoStream;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
/**
* Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the hyperparameters.
@@ -33,6 +37,7 @@ public final class HnswGraphBuilder {
// default random seed for level generation
private static final long DEFAULT_RAND_SEED = System.currentTimeMillis();
+ public static final String HNSW_COMPONENT = "HNSW";
// expose for testing.
public static long randSeed = DEFAULT_RAND_SEED;
@@ -49,6 +54,10 @@ public final class HnswGraphBuilder {
private final int maxConn;
private final int beamWidth;
+ // TODO: how to pass this in?
+ InfoStream infoStream = InfoStream.getDefault();
+ // InfoStream infoStream = new PrintStreamInfoStream(System.out);
+
private final BoundedVectorValues boundedVectors;
private final VectorValues.SearchStrategy searchStrategy;
private final HnswGraph hnsw;
@@ -86,8 +95,17 @@ public final class HnswGraphBuilder {
if (vectors == boundedVectors.raDelegate) {
throw new IllegalArgumentException("Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
}
+ long start = System.nanoTime(), t = start;
for (int node = 1; node < vectors.size(); node++) {
insert(vectors.vectorValue(node));
+ if (node % 10000 == 0) {
+ if (infoStream.isEnabled(HNSW_COMPONENT)) {
+ long now = System.nanoTime();
+ infoStream.message(HNSW_COMPONENT,
+ String.format(Locale.ROOT, "HNSW built %d in %d/%d ms", node, ((now - t) / 1_000_000), ((now - start) / 1_000_000)));
+ t = now;
+ }
+ }
}
return hnsw;
}
@@ -137,10 +155,12 @@ public final class HnswGraphBuilder {
private void addNearestNeighbors(int newNode, Neighbors neighbors) {
// connect the nearest neighbors, relying on the graph's Neighbors' priority queues to drop off distant neighbors
- for (Neighbor neighbor : neighbors) {
- if (hnsw.connect(newNode, neighbor.node(), neighbor.score())) {
- hnsw.connect(neighbor.node(), newNode, neighbor.score());
- }
+ Neighbors.NeighborIterator it = neighbors.iterator();
+ for (int node = it.next(); node != NO_MORE_DOCS; node = it.next()) {
+ float score = it.score();
+ if (hnsw.connect(newNode, node, score)) {
+ hnsw.connect(node, newNode, score);
+ }
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/Neighbor.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/Neighbor.java
deleted file mode 100644
index 01cf231..0000000
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/Neighbor.java
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.lucene.util.hnsw;
-
-/** A neighbor node in the HNSW graph; holds the node ordinal and its distance score. */
-public class Neighbor implements Comparable<Neighbor> {
-
- private int node;
-
- private float score;
-
- public Neighbor(int node, float score) {
- this.node = node;
- this.score = score;
- }
-
- public int node() {
- return node;
- }
-
- public float score() {
- return score;
- }
-
- void update(int node, float score) {
- this.node = node;
- this.score = score;
- }
-
- @Override
- public int compareTo(Neighbor o) {
- if (score == o.score) {
- return o.node - node;
- } else {
- assert node != o.node : "attempt to add the same node " + node + " twice with different scores: " + score + " != " + o.score;
- return Float.compare(score, o.score);
- }
- }
-
- @Override
- public boolean equals(Object other) {
- return other instanceof Neighbor
- && ((Neighbor) other).node == node;
- }
-
- @Override
- public int hashCode() {
- return 39 + 61 * node;
- }
-
- @Override
- public String toString() {
- return "(" + node + ", " + score + ")";
- }
-}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/Neighbors.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/Neighbors.java
index 6ca761b..d193b38 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/Neighbors.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/Neighbors.java
@@ -17,60 +17,70 @@
package org.apache.lucene.util.hnsw;
-import org.apache.lucene.util.PriorityQueue;
+import org.apache.lucene.index.VectorValues;
+import org.apache.lucene.util.LongHeap;
+import org.apache.lucene.util.NumericUtils;
-/** Neighbors queue. */
-public abstract class Neighbors extends PriorityQueue<Neighbor> {
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
- public static Neighbors create(int maxSize, boolean reversed) {
- if (reversed) {
- return new ReverseNeighbors(maxSize);
- } else {
- return new ForwardNeighbors(maxSize);
- }
+/** Neighbors encodes the neighbors of a node in the HNSW graph. */
+public class Neighbors {
+
+ private static final int INITIAL_SIZE = 128;
+
+ public static Neighbors create(int maxSize, VectorValues.SearchStrategy searchStrategy) {
+ return new Neighbors(maxSize, searchStrategy, searchStrategy.reversed);
}
- public abstract boolean reversed();
+ public static Neighbors createReversed(int maxSize, VectorValues.SearchStrategy searchStrategy) {
+ return new Neighbors(maxSize, searchStrategy, !searchStrategy.reversed);
+ }
+
+ private final LongHeap heap;
+ private final VectorValues.SearchStrategy searchStrategy;
// Used to track the number of neighbors visited during a single graph traversal
private int visitedCount;
- private Neighbors(int maxSize) {
- super(maxSize);
+ private Neighbors(int maxSize, VectorValues.SearchStrategy searchStrategy, boolean reversed) {
+ this.searchStrategy = searchStrategy;
+ if (reversed) {
+ heap = LongHeap.create(LongHeap.Order.MAX, maxSize);
+ } else {
+ heap = LongHeap.create(LongHeap.Order.MIN, maxSize);
+ }
}
- private static class ForwardNeighbors extends Neighbors {
- ForwardNeighbors(int maxSize) {
- super(maxSize);
- }
+ public int size() {
+ return heap.size();
+ }
- @Override
- protected boolean lessThan(Neighbor a, Neighbor b) {
- if (a.score() == b.score()) {
- return a.node() > b.node();
- }
- return a.score() < b.score();
- }
+ public boolean reversed() {
+ return searchStrategy.reversed;
+ }
- @Override
- public boolean reversed() { return false; }
+ public void add(int newNode, float newScore) {
+ heap.push(encode(newNode, newScore));
}
- private static class ReverseNeighbors extends Neighbors {
- ReverseNeighbors(int maxSize) {
- super(maxSize);
- }
+ public boolean insertWithOverflow(int newNode, float newScore) {
+ return heap.insertWithOverflow(encode(newNode, newScore));
+ }
- @Override
- protected boolean lessThan(Neighbor a, Neighbor b) {
- if (a.score() == b.score()) {
- return a.node() > b.node();
- }
- return b.score() < a.score();
- }
+  /**
+   * Packs a (node, score) pair into a single long for the heap: the sortable-int form of the score
+   * occupies the high 32 bits, so heap order follows score, and the node ordinal occupies the low 32
+   * bits, breaking ties.
+   */
+  private long encode(int node, float score) {
+    // mask the node so a negative int cannot sign-extend into (and corrupt) the score bits
+    return (((long) NumericUtils.floatToSortableInt(score)) << 32) | (node & 0xFFFFFFFFL);
+  }
+
+  /** Removes the top entry and returns its node ordinal (the low 32 bits of the encoded entry). */
+  public int pop() {
+    return (int) heap.pop();
+  }
+
+  /** Returns the node ordinal of the top entry without removing it; the heap must not be empty. */
+  public int topNode() {
+    return (int) heap.top();
+  }
- @Override
- public boolean reversed() { return true; }
+ public float topScore() {
+ return NumericUtils.sortableIntToFloat((int) (heap.top() >> 32));
}
void setVisitedCount(int visitedCount) {
@@ -81,13 +91,32 @@ public abstract class Neighbors extends PriorityQueue<Neighbor> {
return visitedCount;
}
+  /** Returns an iterator over this queue's (node, score) entries, in the heap's internal (unsorted) order. */
+  public NeighborIterator iterator() {
+    return new NeighborIterator();
+  }
+
+  /**
+   * Iterates over the entries of the enclosing {@link Neighbors}. Each call to {@link #next()}
+   * advances to the next node; {@link #score()} then reports the score stored with that node.
+   * Public because it is the return type of the public {@code iterator()} method.
+   */
+  public class NeighborIterator {
+    // the most recently returned encoded entry: score bits high, node ordinal low
+    private long value;
+    private final LongHeap.LongIterator heapIterator = heap.iterator();
+
+    /** Return the next node, or {@code NO_MORE_DOCS} once every entry has been returned. */
+    public int next() {
+      if (heapIterator.hasNext()) {
+        value = heapIterator.next();
+        return (int) value;
+      }
+      return NO_MORE_DOCS;
+    }
+
+    /** Return the score corresponding to the last node returned by next() */
+    public float score() {
+      return NumericUtils.sortableIntToFloat((int) (value >> 32));
+    }
+  }
+
@Override
public String toString() {
- StringBuilder sb = new StringBuilder();
- sb.append("Neighbors=[");
- this.iterator().forEachRemaining(sb::append);
- sb.append("]");
- return sb.toString();
+ return "Neighbors[" + heap.size() + "]";
}
}
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
index b6b43fa..35b4dfa 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
@@ -28,6 +28,7 @@ import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.LuceneTestCase;
+import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.junit.After;
import org.junit.Before;
@@ -112,6 +113,7 @@ public class TestKnnGraph extends LuceneTestCase {
for (int j = 0; j < dimension; j++) {
values[i][j] = random().nextFloat();
}
+ VectorUtil.l2normalize(values[i]);
}
add(iw, i, values[i]);
if (random().nextInt(10) == 3) {
@@ -185,8 +187,8 @@ public class TestKnnGraph extends LuceneTestCase {
// For this small graph the "search" is exhaustive, so this mostly tests the APIs, the orientation of the
// various priority queues, the scoring function, but not so much the approximate KNN search algo
assertGraphSearch(new int[]{0, 15, 3, 18, 5}, new float[]{0f, 0.1f}, dr);
- // test tiebreaking by docid
- assertGraphSearch(new int[]{11, 1, 8, 14, 21}, new float[]{2, 2}, dr);
+ // Tie-breaking by docid is not asserted here; it has to be applied by the caller after VectorValues.search.
+ // (The former assertion expected docs {11, 1, 8, 14, 21} for the query {2, 2}.)
assertGraphSearch(new int[]{15, 18, 0, 3, 5},new float[]{0.3f, 0.8f}, dr);
}
}
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestVectorValues.java b/lucene/core/src/test/org/apache/lucene/index/TestVectorValues.java
index a9cd946..4d08be5 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestVectorValues.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestVectorValues.java
@@ -35,6 +35,7 @@ import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.TestUtil;
+import org.apache.lucene.util.VectorUtil;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
@@ -574,29 +575,29 @@ public class TestVectorValues extends LuceneTestCase {
String fieldName = "field";
try (Directory dir = newDirectory();
IndexWriter iw = new IndexWriter(dir, iwc)) {
- add(iw, fieldName, 1, 1, new float[]{1});
- add(iw, fieldName, 4, 4, new float[]{4});
+ add(iw, fieldName, 1, 1, new float[]{-1, 0});
+ add(iw, fieldName, 4, 4, new float[]{0, 1});
add(iw, fieldName, 3, 3, null);
- add(iw, fieldName, 2, 2, new float[]{2});
+ add(iw, fieldName, 2, 2, new float[]{1, 0});
iw.forceMerge(1);
try (IndexReader reader = iw.getReader()) {
LeafReader leaf = getOnlyLeafReader(reader);
VectorValues vectorValues = leaf.getVectorValues(fieldName);
- assertEquals(1, vectorValues.dimension());
+ assertEquals(2, vectorValues.dimension());
assertEquals(3, vectorValues.size());
assertEquals("1", leaf.document(vectorValues.nextDoc()).get("id"));
- assertEquals(1f, vectorValues.vectorValue()[0], 0);
+ assertEquals(-1f, vectorValues.vectorValue()[0], 0);
assertEquals("2", leaf.document(vectorValues.nextDoc()).get("id"));
- assertEquals(2f, vectorValues.vectorValue()[0], 0);
+ assertEquals(1, vectorValues.vectorValue()[0], 0);
assertEquals("4", leaf.document(vectorValues.nextDoc()).get("id"));
- assertEquals(4f, vectorValues.vectorValue()[0], 0);
+ assertEquals(0, vectorValues.vectorValue()[0], 0);
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
RandomAccessVectorValues ra = ((RandomAccessVectorValuesProducer) vectorValues).randomAccess();
- assertEquals(1f, ra.vectorValue(0)[0], 0);
- assertEquals(2f, ra.vectorValue(1)[0], 0);
- assertEquals(4f, ra.vectorValue(2)[0], 0);
+ assertEquals(-1f, ra.vectorValue(0)[0], 0);
+ assertEquals(1f, ra.vectorValue(1)[0], 0);
+ assertEquals(0f, ra.vectorValue(2)[0], 0);
}
}
}
@@ -735,6 +736,7 @@ public class TestVectorValues extends LuceneTestCase {
for (int i = 0; i < dim; i++) {
v[i] = random().nextFloat();
}
+ VectorUtil.l2normalize(v);
return v;
}
diff --git a/lucene/core/src/test/org/apache/lucene/util/TestLongHeap.java b/lucene/core/src/test/org/apache/lucene/util/TestLongHeap.java
new file mode 100644
index 0000000..9f34811
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/util/TestLongHeap.java
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.util;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
+import static org.apache.lucene.util.LongHeap.Order.MAX;
+import static org.apache.lucene.util.LongHeap.Order.MIN;
+
+/** Tests for {@link LongHeap}. */
+public class TestLongHeap extends LuceneTestCase {
+
+  /** A heap in natural {@code long} order whose heap invariant can be checked after each operation. */
+  private static class AssertingLongHeap extends LongHeap {
+    public AssertingLongHeap(int count) {
+      super(count);
+    }
+
+    @Override
+    public boolean lessThan(long a, long b) {
+      return (a < b);
+    }
+
+    /**
+     * Asserts the min-heap invariant: every element is greater than or equal to its parent. The
+     * heap array is 1-based (slot 0 is unused), so every index from 2 upward has a parent at
+     * {@code i >>> 1}; starting at 2 makes sure the root's children are checked as well.
+     */
+    protected final void checkValidity() {
+      long[] heapArray = getHeapArray();
+      for (int i = 2; i <= size(); i++) {
+        int parent = i >>> 1;
+        if (lessThan(heapArray[parent], heapArray[i]) == false) {
+          assertEquals(heapArray[parent], heapArray[i]);
+        }
+      }
+    }
+  }
+
+  public void testPQ() {
+    testPQ(atLeast(10000), random());
+  }
+
+  public static void testPQ(int count, Random gen) {
+    LongHeap pq = LongHeap.create(MIN, count);
+    long sum = 0, sum2 = 0;
+
+    for (int i = 0; i < count; i++) {
+      long next = gen.nextLong();
+      sum += next;
+      pq.push(next);
+    }
+
+    long last = Long.MIN_VALUE;
+    for (int i = 0; i < count; i++) {
+      long next = pq.pop();
+      assertTrue(next >= last);
+      last = next;
+      sum2 += last;
+    }
+
+    // both sums may overflow, but they wrap identically, so equality still shows nothing was lost
+    assertEquals(sum, sum2);
+  }
+
+  public void testClear() {
+    LongHeap pq = LongHeap.create(MIN, 3);
+    pq.push(2);
+    pq.push(3);
+    pq.push(1);
+    assertEquals(3, pq.size());
+    pq.clear();
+    assertEquals(0, pq.size());
+  }
+
+  public void testExceedBounds() {
+    LongHeap pq = LongHeap.create(MIN, 1);
+    pq.push(2);
+    expectThrows(ArrayIndexOutOfBoundsException.class, () -> pq.push(0));
+    assertEquals(2, pq.size()); // the heap is unusable at this point
+  }
+
+  public void testFixedSize() {
+    LongHeap pq = LongHeap.create(MIN, 3);
+    pq.insertWithOverflow(2);
+    pq.insertWithOverflow(3);
+    pq.insertWithOverflow(1);
+    pq.insertWithOverflow(5);
+    pq.insertWithOverflow(7);
+    pq.insertWithOverflow(1);
+    assertEquals(3, pq.size());
+    assertEquals(3, pq.top());
+  }
+
+  public void testFixedSizeMax() {
+    LongHeap pq = LongHeap.create(MAX, 3);
+    pq.insertWithOverflow(2);
+    pq.insertWithOverflow(3);
+    pq.insertWithOverflow(1);
+    pq.insertWithOverflow(5);
+    pq.insertWithOverflow(7);
+    pq.insertWithOverflow(1);
+    assertEquals(3, pq.size());
+    assertEquals(2, pq.top());
+  }
+
+  public void testDuplicateValues() {
+    LongHeap pq = LongHeap.create(MIN, 3);
+    pq.push(2);
+    pq.push(3);
+    pq.push(1);
+    assertEquals(1, pq.top());
+    pq.updateTop(3);
+    assertEquals(3, pq.size());
+    assertArrayEquals(new long[]{0, 2, 3, 3}, pq.getHeapArray());
+  }
+
+  public void testInsertions() {
+    Random random = random();
+    int numDocsInPQ = TestUtil.nextInt(random, 1, 100);
+    AssertingLongHeap pq = new AssertingLongHeap(numDocsInPQ);
+    Long lastLeast = null;
+
+    // Basic insertion of new content
+    for (int i = 0; i < numDocsInPQ * 10; i++) {
+      // clear the sign bit rather than use Math.abs, which leaves Long.MIN_VALUE negative
+      long newEntry = random.nextLong() & Long.MAX_VALUE;
+      pq.insertWithOverflow(newEntry);
+      pq.checkValidity();
+      long newLeast = pq.top();
+      if ((lastLeast != null) && (newLeast != newEntry)
+          && (newLeast != lastLeast)) {
+        // If there has been a change of least entry and it wasn't our new
+        // addition we expect the scores to increase
+        assertTrue(newLeast <= newEntry);
+        assertTrue(newLeast >= lastLeast);
+      }
+      lastLeast = newLeast;
+    }
+  }
+
+  public void testIteratorEmpty() {
+    LongHeap queue = LongHeap.create(MIN, 3);
+    LongHeap.LongIterator it = queue.iterator();
+    assertFalse(it.hasNext());
+    expectThrows(IllegalStateException.class, it::next);
+  }
+
+  public void testIteratorOne() {
+    LongHeap queue = LongHeap.create(MIN, 3);
+
+    queue.push(1);
+    LongHeap.LongIterator it = queue.iterator();
+    assertTrue(it.hasNext());
+    assertEquals(1, it.next());
+    assertFalse(it.hasNext());
+    expectThrows(IllegalStateException.class, it::next);
+  }
+
+  public void testIteratorTwo() {
+    LongHeap queue = LongHeap.create(MIN, 3);
+
+    queue.push(1);
+    queue.push(2);
+    LongHeap.LongIterator it = queue.iterator();
+    assertTrue(it.hasNext());
+    assertEquals(1, it.next());
+    assertTrue(it.hasNext());
+    assertEquals(2, it.next());
+    assertFalse(it.hasNext());
+    expectThrows(IllegalStateException.class, it::next);
+  }
+
+  public void testIteratorRandom() {
+    LongHeap.Order order = random().nextBoolean() ? MIN : MAX;
+    final int maxSize = TestUtil.nextInt(random(), 1, 20);
+    LongHeap queue = LongHeap.create(order, maxSize);
+    final int iters = atLeast(100);
+    final List<Long> expected = new ArrayList<>();
+    for (int iter = 0; iter < iters; ++iter) {
+      if (queue.size() == 0 || (queue.size() < maxSize && random().nextBoolean())) {
+        final long value = random().nextInt(10);
+        queue.push(value);
+        expected.add(value);
+      } else {
+        expected.remove(Long.valueOf(queue.pop()));
+      }
+      List<Long> actual = new ArrayList<>();
+      LongHeap.LongIterator it = queue.iterator();
+      while (it.hasNext()) {
+        actual.add(it.next());
+      }
+      CollectionUtil.introSort(expected);
+      CollectionUtil.introSort(actual);
+      assertEquals(expected, actual);
+    }
+  }
+
+  public void testUnbounded() {
+    LongHeap pq = LongHeap.create(MAX, -1);
+    int num = random().nextInt(100) + 1;
+    long maxValue = Long.MIN_VALUE;
+    for (int i = 0; i < num; i++) {
+      long value = random().nextLong();
+      if (random().nextBoolean()) {
+        pq.push(value);
+      } else {
+        pq.insertWithOverflow(value);
+      }
+      maxValue = Math.max(maxValue, value);
+    }
+    assertEquals(num, pq.size());
+    assertEquals(maxValue, pq.top());
+    long last = maxValue;
+    int count = 0;
+    while (pq.size() > 0) {
+      long next = pq.pop();
+      ++count;
+      assertTrue(next <= last);
+      last = next;
+    }
+    assertEquals(num, count);
+  }
+
+}
diff --git a/lucene/core/src/test/org/apache/lucene/util/TestNumericUtils.java b/lucene/core/src/test/org/apache/lucene/util/TestNumericUtils.java
index 7dbbfd2..08567a5 100644
--- a/lucene/core/src/test/org/apache/lucene/util/TestNumericUtils.java
+++ b/lucene/core/src/test/org/apache/lucene/util/TestNumericUtils.java
@@ -488,4 +488,5 @@ public class TestNumericUtils extends LuceneTestCase {
Integer.signum(left.compareTo(right)));
}
}
+
}
diff --git a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java
index fde3a8b..e078a04 100644
--- a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java
+++ b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java
@@ -18,6 +18,8 @@ package org.apache.lucene.util;
public class TestVectorUtil extends LuceneTestCase {
+ public static final double DELTA = 1e-4;
+
public void testBasicDotProduct() {
assertEquals(5, VectorUtil.dotProduct(new float[]{1, 2, 3}, new float[]{-10, 0, 5}), 0);
}
@@ -25,7 +27,7 @@ public class TestVectorUtil extends LuceneTestCase {
public void testSelfDotProduct() {
// the dot product of a vector with itself is equal to the sum of the squares of its components
float[] v = randomVector();
- assertEquals(l2(v), VectorUtil.dotProduct(v, v), 1e-4);
+ assertEquals(l2(v), VectorUtil.dotProduct(v, v), DELTA);
}
public void testOrthogonalDotProduct() {
@@ -36,24 +38,46 @@ public class TestVectorUtil extends LuceneTestCase {
float[] u = new float[2];
u[0] = v[1];
u[1] = -v[0];
- assertEquals(0, VectorUtil.dotProduct(u, v), 1e-4);
+ assertEquals(0, VectorUtil.dotProduct(u, v), DELTA);
+ }
+
+ public void testDotProductThrowsForDimensionMismatch() {
+ float[] v = {1, 0, 0}, u = {0, 1};
+ expectThrows(IllegalArgumentException.class, () -> VectorUtil.dotProduct(u, v));
}
public void testSelfSquareDistance() {
// the l2 distance of a vector with itself is zero
float[] v = randomVector();
- assertEquals(0, VectorUtil.squareDistance(v, v), 1e-4);
+ assertEquals(0, VectorUtil.squareDistance(v, v), DELTA);
}
public void testBasicSquareDistance() {
assertEquals(12, VectorUtil.squareDistance(new float[]{1, 2, 3}, new float[]{-1, 0, 5}), 0);
}
+ public void testSquareDistanceThrowsForDimensionMismatch() {
+ float[] v = {1, 0, 0}, u = {0, 1};
+ expectThrows(IllegalArgumentException.class, () -> VectorUtil.squareDistance(u, v));
+ }
+
public void testRandomSquareDistance() {
// the square distance of a vector with its inverse is equal to four times the sum of squares of its components
float[] v = randomVector();
float[] u = negative(v);
- assertEquals(4 * l2(v), VectorUtil.squareDistance(u, v), 1e-4);
+ assertEquals(4 * l2(v), VectorUtil.squareDistance(u, v), DELTA);
+ }
+
+ public void testNormalize() {
+ float[] v = randomVector();
+ v[random().nextInt(v.length)] = 1; // ensure vector is not all zeroes
+ VectorUtil.l2normalize(v);
+ assertEquals(1f, l2(v), DELTA);
+ }
+
+ public void testNormalizeZeroThrows() {
+ float[] v = {0, 0, 0};
+ expectThrows(IllegalArgumentException.class, () -> VectorUtil.l2normalize(v));
}
private float l2(float[] v) {
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
index ce8a6ed..9716f4b 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
@@ -165,10 +165,16 @@ public class KnnGraphTester {
usage();
}
if (reindex) {
+ if (docVectorsPath == null) {
+ throw new IllegalArgumentException("-docs argument is required when indexing");
+ }
reindexTimeMsec = createIndex(Paths.get(docVectorsPath), indexPath);
}
switch (operation) {
case "-search":
+ if (docVectorsPath == null) {
+ throw new IllegalArgumentException("-docs argument is required when searching");
+ }
testSearch(indexPath, Paths.get(queryPath), getNN(Paths.get(docVectorsPath), Paths.get(queryPath)));
break;
case "-forceMerge":
@@ -420,16 +426,15 @@ public class KnnGraphTester {
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer();
offset += blockSize;
- Neighbors queue = Neighbors.create(topK, SEARCH_STRATEGY.reversed);
+ Neighbors queue = Neighbors.create(topK, SEARCH_STRATEGY);
for (; j < numDocs && vectors.hasRemaining(); j++) {
vectors.get(vector);
float d = SEARCH_STRATEGY.compare(query, vector);
- queue.insertWithOverflow(new Neighbor(j, d));
+ queue.insertWithOverflow(j, d);
}
result[i] = new int[topK];
for (int k = topK - 1; k >= 0; k--) {
- Neighbor n = queue.pop();
- result[i][k] = n.node();
- //System.out.print(" " + n);
+ // the bounded queue's top is the worst retained neighbor, so results are filled from the back
+ result[i][k] = queue.pop();
}
if (quiet == false && (i + 1) % 10 == 0) {
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java
index 8f50a1d..ce6fb5c 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java
@@ -41,6 +41,7 @@ import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
+import org.apache.lucene.util.VectorUtil;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
@@ -104,42 +105,43 @@ public class TestHnsw extends LuceneTestCase {
// run some searches
Neighbors nn = HnswGraph.search(new float[]{1, 0}, 10, 5, vectors.randomAccess(), hnsw.getGraphValues(), random());
int sum = 0;
- for (Neighbor n : nn) {
- sum += n.node();
+ Neighbors.NeighborIterator it = nn.iterator();
+ for (int node = it.next(); node != NO_MORE_DOCS; node = it.next()) {
+ sum += node;
}
// We expect to get approximately 100% recall; the lowest docIds are closest to zero; sum(0,9) = 45
assertTrue("sum(result docs)=" + sum, sum < 75);
}
- public void testMaxConnections() throws Exception {
+ public void testMaxConnections() {
// verify that maxConnections is observed, and that the retained arcs point to the best-scoring neighbors
HnswGraph graph = new HnswGraph(1, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW);
- graph.connectNodes(0, 1, 1);
- assertArrayEquals(new int[]{1}, graph.getNeighbors(0));
- assertArrayEquals(new int[]{0}, graph.getNeighbors(1));
- graph.connectNodes(0, 2, 2);
- assertArrayEquals(new int[]{2}, graph.getNeighbors(0));
- assertArrayEquals(new int[]{0}, graph.getNeighbors(1));
- assertArrayEquals(new int[]{0}, graph.getNeighbors(2));
- graph.connectNodes(2, 3, 1);
- assertArrayEquals(new int[]{2}, graph.getNeighbors(0));
- assertArrayEquals(new int[]{0}, graph.getNeighbors(1));
- assertArrayEquals(new int[]{0}, graph.getNeighbors(2));
- assertArrayEquals(new int[]{2}, graph.getNeighbors(3));
+ graph.connectNodes(0, 1, 0);
+ assertArrayEquals(new int[]{1}, graph.getNeighborNodes(0));
+ assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1));
+ graph.connectNodes(0, 2, 0.4f);
+ assertArrayEquals(new int[]{2}, graph.getNeighborNodes(0));
+ assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1));
+ assertArrayEquals(new int[]{0}, graph.getNeighborNodes(2));
+ graph.connectNodes(2, 3, 0);
+ assertArrayEquals(new int[]{2}, graph.getNeighborNodes(0));
+ assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1));
+ assertArrayEquals(new int[]{0}, graph.getNeighborNodes(2));
+ assertArrayEquals(new int[]{2}, graph.getNeighborNodes(3));
graph = new HnswGraph(1, VectorValues.SearchStrategy.EUCLIDEAN_HNSW);
graph.connectNodes(0, 1, 1);
- assertArrayEquals(new int[]{1}, graph.getNeighbors(0));
- assertArrayEquals(new int[]{0}, graph.getNeighbors(1));
+ assertArrayEquals(new int[]{1}, graph.getNeighborNodes(0));
+ assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1));
graph.connectNodes(0, 2, 2);
- assertArrayEquals(new int[]{1}, graph.getNeighbors(0));
- assertArrayEquals(new int[]{0}, graph.getNeighbors(1));
- assertArrayEquals(new int[]{0}, graph.getNeighbors(2));
+ assertArrayEquals(new int[]{1}, graph.getNeighborNodes(0));
+ assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1));
+ assertArrayEquals(new int[]{0}, graph.getNeighborNodes(2));
graph.connectNodes(2, 3, 1);
- assertArrayEquals(new int[]{1}, graph.getNeighbors(0));
- assertArrayEquals(new int[]{0}, graph.getNeighbors(1));
- assertArrayEquals(new int[]{3}, graph.getNeighbors(2));
- assertArrayEquals(new int[]{2}, graph.getNeighbors(3));
+ assertArrayEquals(new int[]{1}, graph.getNeighborNodes(0));
+ assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1));
+ assertArrayEquals(new int[]{3}, graph.getNeighborNodes(2));
+ assertArrayEquals(new int[]{2}, graph.getNeighborNodes(3));
}
/** Returns vectors evenly distributed around the unit circle.
@@ -232,11 +234,11 @@ public class TestHnsw extends LuceneTestCase {
for (int node = 0; node < size; node ++) {
g.seek(node);
h.seek(node);
- assertEquals("arcs differ for node " + node, getNeighbors(g), getNeighbors(h));
+ assertEquals("arcs differ for node " + node, getNeighborNodes(g), getNeighborNodes(h));
}
}
- private Set<Integer> getNeighbors(KnnGraphValues g) throws IOException {
+ private Set<Integer> getNeighborNodes(KnnGraphValues g) throws IOException {
Set<Integer> neighbors = new HashSet<>();
for (int n = g.nextNeighbor(); n != NO_MORE_DOCS; n = g.nextNeighbor()) {
neighbors.add(n);
@@ -259,38 +261,22 @@ public class TestHnsw extends LuceneTestCase {
public void testNeighbors() {
// make sure we have the sign correct
- Neighbors nn = Neighbors.create(2, false);
- Neighbor a = new Neighbor(1, 10);
- Neighbor b = new Neighbor(2, 20);
- Neighbor c = new Neighbor(3, 30);
- assertNull(nn.insertWithOverflow(b));
- assertNull(nn.insertWithOverflow(a));
- assertSame(a, nn.insertWithOverflow(c));
- assertEquals(20, (int) nn.top().score());
- assertEquals(20, (int) nn.pop().score());
- assertEquals(30, (int) nn.top().score());
- assertEquals(30, (int) nn.pop().score());
-
- Neighbors fn = Neighbors.create(2, true);
- assertNull(fn.insertWithOverflow(b));
- assertNull(fn.insertWithOverflow(a));
- assertSame(c, fn.insertWithOverflow(c));
- assertEquals(20, (int) fn.top().score());
- assertEquals(20, (int) fn.pop().score());
- assertEquals(10, (int) fn.top().score());
- assertEquals(10, (int) fn.pop().score());
- }
-
- @SuppressWarnings("SelfComparison")
- public void testNeighbor() {
- Neighbor a = new Neighbor(1, 10);
- Neighbor b = new Neighbor(2, 20);
- Neighbor c = new Neighbor(3, 20);
- assertEquals(0, a.compareTo(a));
- assertEquals(-1, a.compareTo(b));
- assertEquals(1, b.compareTo(a));
- assertEquals(1, b.compareTo(c));
- assertEquals(-1, c.compareTo(b));
+ Neighbors nn = Neighbors.create(2, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW);
+ assertTrue(nn.insertWithOverflow(2, 0.5f));
+ assertTrue(nn.insertWithOverflow(1, 0.2f));
+ assertTrue(nn.insertWithOverflow(3, 1f));
+ assertEquals(0.5f, nn.topScore(), 0);
+ nn.pop();
+ assertEquals(1f, nn.topScore(), 0);
+ nn.pop();
+
+ Neighbors fn = Neighbors.create(2, VectorValues.SearchStrategy.EUCLIDEAN_HNSW);
+ assertTrue(fn.insertWithOverflow(2, 2));
+ assertTrue(fn.insertWithOverflow(1, 1));
+ assertFalse(fn.insertWithOverflow(3, 3));
+ assertEquals(2f, fn.topScore(), 0);
+ fn.pop();
+ assertEquals(1f, fn.topScore(), 0);
}
private static float[] randomVector(Random random, int dim) {
@@ -298,6 +284,7 @@ public class TestHnsw extends LuceneTestCase {
for (int i = 0; i < dim; i++) {
vec[i] = random.nextFloat();
}
+ VectorUtil.l2normalize(vec);
return vec;
}
@@ -457,9 +444,9 @@ public class TestHnsw extends LuceneTestCase {
}
public void testHnswGraphBuilderInvalid() {
- expectThrows(NullPointerException.class, () -> new HnswGraphBuilder(null, 0, 0, 0));
- expectThrows(IllegalArgumentException.class, () -> new HnswGraphBuilder(new RandomVectorValues(1, 1, random()), 0, 10, 0));
- expectThrows(IllegalArgumentException.class, () -> new HnswGraphBuilder(new RandomVectorValues(1, 1, random()), 10, 0, 0));
+ expectThrows(NullPointerException.class, () -> new HnswGraphBuilder(null, 0, 0, 0));
+ expectThrows(IllegalArgumentException.class, () -> new HnswGraphBuilder(new RandomVectorValues(1, 1, random()), 0, 10, 0));
+ expectThrows(IllegalArgumentException.class, () -> new HnswGraphBuilder(new RandomVectorValues(1, 1, random()), 10, 0, 0));
}
}