You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by pa...@apache.org on 2023/05/16 06:20:37 UTC
[lucene] branch main updated: Optimize HNSW diversity calculation (#12235)
This is an automated email from the ASF dual-hosted git repository.
patrickz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/lucene.git
The following commit(s) were added to refs/heads/main by this push:
new 8af305892d7 Optimize HNSW diversity calculation (#12235)
8af305892d7 is described below
commit 8af305892d726c180f03316c73aebf8183c2e481
Author: Patrick Zhai <zh...@users.noreply.github.com>
AuthorDate: Mon May 15 23:20:31 2023 -0700
Optimize HNSW diversity calculation (#12235)
---
lucene/CHANGES.txt | 2 +
.../synonym/word2vec/Word2VecSynonymProvider.java | 2 +-
.../lucene90/Lucene90HnswGraphBuilder.java | 8 +-
.../lucene91/Lucene91HnswGraphBuilder.java | 8 +-
.../apache/lucene/util/hnsw/HnswGraphBuilder.java | 124 ++++++++++++++++-----
.../org/apache/lucene/util/hnsw/NeighborArray.java | 88 ++++++++++++---
.../apache/lucene/util/hnsw/OnHeapHnswGraph.java | 12 +-
.../apache/lucene/util/hnsw/TestNeighborArray.java | 114 ++++++++++++++-----
8 files changed, 272 insertions(+), 86 deletions(-)
diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 7876d4bfbf5..308e28f27cd 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -128,6 +128,8 @@ Optimizations
* GITHUB#12286 Toposort use iterator to avoid stackoverflow. (Tang Donghai)
+* GITHUB#12235: Optimize HNSW diversity calculation. (Patrick Zhai)
+
Bug Fixes
---------------------
(No changes)
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java
index 1d1f06fc991..3784d9372fe 100644
--- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java
@@ -85,7 +85,7 @@ public class Word2VecSynonymProvider {
SIMILARITY_FUNCTION,
hnswGraph,
null,
- word2VecModel.size());
+ Integer.MAX_VALUE);
int size = synonyms.size();
for (int i = 0; i < size; i++) {
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java
index e7f16b4f3fc..4b1f7068a5f 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java
@@ -183,10 +183,10 @@ public final class Lucene90HnswGraphBuilder {
int size = neighbors.size();
for (int i = 0; i < size; i++) {
int nbr = neighbors.node()[i];
- Lucene90NeighborArray nbrNbr = hnsw.getNeighbors(nbr);
- nbrNbr.add(node, neighbors.score()[i]);
- if (nbrNbr.size() > maxConn) {
- diversityUpdate(nbrNbr);
+ Lucene90NeighborArray nbrsOfNbr = hnsw.getNeighbors(nbr);
+ nbrsOfNbr.add(node, neighbors.score()[i]);
+ if (nbrsOfNbr.size() > maxConn) {
+ diversityUpdate(nbrsOfNbr);
}
}
}
diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java
index b0e9d160457..c82920181cc 100644
--- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java
+++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java
@@ -204,10 +204,10 @@ public final class Lucene91HnswGraphBuilder {
int size = neighbors.size();
for (int i = 0; i < size; i++) {
int nbr = neighbors.node[i];
- Lucene91NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr);
- nbrNbr.add(node, neighbors.score[i]);
- if (nbrNbr.size() > maxConn) {
- diversityUpdate(nbrNbr);
+ Lucene91NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr);
+ nbrsOfNbr.add(node, neighbors.score[i]);
+ if (nbrsOfNbr.size() > maxConn) {
+ diversityUpdate(nbrsOfNbr);
}
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index edefb696b8c..e2a57a303c6 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -218,7 +218,9 @@ public final class HnswGraphBuilder<T> {
case BYTE -> this.similarityFunction.compare(
binaryValue, (byte[]) vectorsCopy.vectorValue(newNeighbor));
};
- newNeighbors.insertSorted(newNeighbor, score);
+ // we are not sure whether the previous graph contains
+ // unchecked nodes, so we have to assume they're all unchecked
+ newNeighbors.addOutOfOrder(newNeighbor, score);
}
}
}
@@ -314,11 +316,11 @@ public final class HnswGraphBuilder<T> {
int size = neighbors.size();
for (int i = 0; i < size; i++) {
int nbr = neighbors.node[i];
- NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr);
- nbrNbr.insertSorted(node, neighbors.score[i]);
- if (nbrNbr.size() > maxConnOnLevel) {
- int indexToRemove = findWorstNonDiverse(nbrNbr);
- nbrNbr.removeIndex(indexToRemove);
+ NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr);
+ nbrsOfNbr.addOutOfOrder(node, neighbors.score[i]);
+ if (nbrsOfNbr.size() > maxConnOnLevel) {
+ int indexToRemove = findWorstNonDiverse(nbrsOfNbr);
+ nbrsOfNbr.removeIndex(indexToRemove);
}
}
}
@@ -333,7 +335,7 @@ public final class HnswGraphBuilder<T> {
float cScore = candidates.score[i];
assert cNode < hnsw.size();
if (diversityCheck(cNode, cScore, neighbors)) {
- neighbors.add(cNode, cScore);
+ neighbors.addInOrder(cNode, cScore);
}
}
}
@@ -345,7 +347,7 @@ public final class HnswGraphBuilder<T> {
// sorted from worst to best
for (int i = 0; i < candidateCount; i++) {
float maxSimilarity = candidates.topScore();
- scratch.add(candidates.pop(), maxSimilarity);
+ scratch.addInOrder(candidates.pop(), maxSimilarity);
}
}
@@ -400,50 +402,116 @@ public final class HnswGraphBuilder<T> {
* neighbours
*/
private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
+ int[] uncheckedIndexes = neighbors.sort();
+ if (uncheckedIndexes == null) {
+ // all nodes are checked, we will directly return the most distant one
+ return neighbors.size() - 1;
+ }
+ int uncheckedCursor = uncheckedIndexes.length - 1;
for (int i = neighbors.size() - 1; i > 0; i--) {
- if (isWorstNonDiverse(i, neighbors)) {
+ if (uncheckedCursor < 0) {
+ // no unchecked node left
+ break;
+ }
+ if (isWorstNonDiverse(i, neighbors, uncheckedIndexes, uncheckedCursor)) {
return i;
}
+ if (i == uncheckedIndexes[uncheckedCursor]) {
+ uncheckedCursor--;
+ }
}
return neighbors.size() - 1;
}
- private boolean isWorstNonDiverse(int candidateIndex, NeighborArray neighbors)
+ private boolean isWorstNonDiverse(
+ int candidateIndex, NeighborArray neighbors, int[] uncheckedIndexes, int uncheckedCursor)
throws IOException {
int candidateNode = neighbors.node[candidateIndex];
return switch (vectorEncoding) {
case BYTE -> isWorstNonDiverse(
- candidateIndex, (byte[]) vectors.vectorValue(candidateNode), neighbors);
+ candidateIndex,
+ (byte[]) vectors.vectorValue(candidateNode),
+ neighbors,
+ uncheckedIndexes,
+ uncheckedCursor);
case FLOAT32 -> isWorstNonDiverse(
- candidateIndex, (float[]) vectors.vectorValue(candidateNode), neighbors);
+ candidateIndex,
+ (float[]) vectors.vectorValue(candidateNode),
+ neighbors,
+ uncheckedIndexes,
+ uncheckedCursor);
};
}
private boolean isWorstNonDiverse(
- int candidateIndex, float[] candidateVector, NeighborArray neighbors) throws IOException {
+ int candidateIndex,
+ float[] candidateVector,
+ NeighborArray neighbors,
+ int[] uncheckedIndexes,
+ int uncheckedCursor)
+ throws IOException {
float minAcceptedSimilarity = neighbors.score[candidateIndex];
- for (int i = candidateIndex - 1; i >= 0; i--) {
- float neighborSimilarity =
- similarityFunction.compare(
- candidateVector, (float[]) vectorsCopy.vectorValue(neighbors.node[i]));
- // candidate node is too similar to node i given its score relative to the base node
- if (neighborSimilarity >= minAcceptedSimilarity) {
- return true;
+ if (candidateIndex == uncheckedIndexes[uncheckedCursor]) {
+ // the candidate itself is unchecked
+ for (int i = candidateIndex - 1; i >= 0; i--) {
+ float neighborSimilarity =
+ similarityFunction.compare(
+ candidateVector, (float[]) vectorsCopy.vectorValue(neighbors.node[i]));
+ // candidate node is too similar to node i given its score relative to the base node
+ if (neighborSimilarity >= minAcceptedSimilarity) {
+ return true;
+ }
+ }
+ } else {
+ // else we just need to make sure candidate does not violate diversity with the (newly
+ // inserted) unchecked nodes
+ assert candidateIndex > uncheckedIndexes[uncheckedCursor];
+ for (int i = uncheckedCursor; i >= 0; i--) {
+ float neighborSimilarity =
+ similarityFunction.compare(
+ candidateVector,
+ (float[]) vectorsCopy.vectorValue(neighbors.node[uncheckedIndexes[i]]));
+ // candidate node is too similar to node i given its score relative to the base node
+ if (neighborSimilarity >= minAcceptedSimilarity) {
+ return true;
+ }
}
}
return false;
}
private boolean isWorstNonDiverse(
- int candidateIndex, byte[] candidateVector, NeighborArray neighbors) throws IOException {
+ int candidateIndex,
+ byte[] candidateVector,
+ NeighborArray neighbors,
+ int[] uncheckedIndexes,
+ int uncheckedCursor)
+ throws IOException {
float minAcceptedSimilarity = neighbors.score[candidateIndex];
- for (int i = candidateIndex - 1; i >= 0; i--) {
- float neighborSimilarity =
- similarityFunction.compare(
- candidateVector, (byte[]) vectorsCopy.vectorValue(neighbors.node[i]));
- // candidate node is too similar to node i given its score relative to the base node
- if (neighborSimilarity >= minAcceptedSimilarity) {
- return true;
+ if (candidateIndex == uncheckedIndexes[uncheckedCursor]) {
+ // the candidate itself is unchecked
+ for (int i = candidateIndex - 1; i >= 0; i--) {
+ float neighborSimilarity =
+ similarityFunction.compare(
+ candidateVector, (byte[]) vectorsCopy.vectorValue(neighbors.node[i]));
+ // candidate node is too similar to node i given its score relative to the base node
+ if (neighborSimilarity >= minAcceptedSimilarity) {
+ return true;
+ }
+ }
+ } else {
+ // else we just need to make sure candidate does not violate diversity with the (newly
+ // inserted) unchecked nodes
+ assert candidateIndex > uncheckedIndexes[uncheckedCursor];
+ for (int i = uncheckedCursor; i >= 0; i--) {
+ float neighborSimilarity =
+ similarityFunction.compare(
+ candidateVector,
+ (byte[]) vectorsCopy.vectorValue(neighbors.node[uncheckedIndexes[i]]));
+ // candidate node is too similar to node i given its score relative to the base node
+ if (neighborSimilarity >= minAcceptedSimilarity) {
+ return true;
+ }
}
}
return false;
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
index ec1b5ec3e89..a23b9b5254e 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
@@ -34,6 +34,7 @@ public class NeighborArray {
float[] score;
int[] node;
+ private int sortedNodeSize;
public NeighborArray(int maxSize, boolean descOrder) {
node = new int[maxSize];
@@ -43,9 +44,10 @@ public class NeighborArray {
/**
* Add a new node to the NeighborArray. The new node must be worse than all previously stored
- * nodes.
+ * nodes. This cannot be called after {@link #addOutOfOrder(int, float)}
*/
- public void add(int newNode, float newScore) {
+ public void addInOrder(int newNode, float newScore) {
+ assert size == sortedNodeSize : "cannot call addInOrder after addOutOfOrder";
if (size == node.length) {
node = ArrayUtil.grow(node);
score = ArrayUtil.growExact(score, node.length);
@@ -59,23 +61,72 @@ public class NeighborArray {
node[size] = newNode;
score[size] = newScore;
++size;
+ ++sortedNodeSize;
}
- /** Add a new node to the NeighborArray into a correct sort position according to its score. */
- public void insertSorted(int newNode, float newScore) {
+ /** Add node and score but do not insert as sorted */
+ public void addOutOfOrder(int newNode, float newScore) {
if (size == node.length) {
node = ArrayUtil.grow(node);
score = ArrayUtil.growExact(score, node.length);
}
+ node[size] = newNode;
+ score[size] = newScore;
+ size++;
+ }
+
+ /**
+ * Sort the array according to scores, and return the sorted indexes of previously unsorted
+ * nodes (unchecked nodes)
+ *
+ * @return indexes of newly sorted (unchecked) nodes, in ascending order, or null if the array is
+ * already fully sorted
+ */
+ public int[] sort() {
+ if (size == sortedNodeSize) {
+ // all nodes checked and sorted
+ return null;
+ }
+ assert sortedNodeSize < size;
+ int[] uncheckedIndexes = new int[size - sortedNodeSize];
+ int count = 0;
+ while (sortedNodeSize != size) {
+ uncheckedIndexes[count] = insertSortedInternal(); // sortedNodeSize is increased inside
+ for (int i = 0; i < count; i++) {
+ if (uncheckedIndexes[i] >= uncheckedIndexes[count]) {
+ // the previously inserted nodes have been shifted
+ uncheckedIndexes[i]++;
+ }
+ }
+ count++;
+ }
+ Arrays.sort(uncheckedIndexes);
+ return uncheckedIndexes;
+ }
+
+ /** insert the first unsorted node into its sorted position */
+ private int insertSortedInternal() {
+ assert sortedNodeSize < size : "Call this method only when there's unsorted node";
+ int tmpNode = node[sortedNodeSize];
+ float tmpScore = score[sortedNodeSize];
int insertionPoint =
scoresDescOrder
- ? descSortFindRightMostInsertionPoint(newScore)
- : ascSortFindRightMostInsertionPoint(newScore);
- System.arraycopy(node, insertionPoint, node, insertionPoint + 1, size - insertionPoint);
- System.arraycopy(score, insertionPoint, score, insertionPoint + 1, size - insertionPoint);
- node[insertionPoint] = newNode;
- score[insertionPoint] = newScore;
- ++size;
+ ? descSortFindRightMostInsertionPoint(tmpScore, sortedNodeSize)
+ : ascSortFindRightMostInsertionPoint(tmpScore, sortedNodeSize);
+ System.arraycopy(
+ node, insertionPoint, node, insertionPoint + 1, sortedNodeSize - insertionPoint);
+ System.arraycopy(
+ score, insertionPoint, score, insertionPoint + 1, sortedNodeSize - insertionPoint);
+ node[insertionPoint] = tmpNode;
+ score[insertionPoint] = tmpScore;
+ ++sortedNodeSize;
+ return insertionPoint;
+ }
+
+ /** This method is for test only. */
+ void insertSorted(int newNode, float newScore) {
+ addOutOfOrder(newNode, newScore);
+ insertSortedInternal();
}
public int size() {
@@ -97,15 +148,20 @@ public class NeighborArray {
public void clear() {
size = 0;
+ sortedNodeSize = 0;
}
public void removeLast() {
size--;
+ sortedNodeSize = Math.min(sortedNodeSize, size);
}
public void removeIndex(int idx) {
System.arraycopy(node, idx + 1, node, idx, size - idx - 1);
System.arraycopy(score, idx + 1, score, idx, size - idx - 1);
+ if (idx < sortedNodeSize) {
+ sortedNodeSize--;
+ }
size--;
}
@@ -114,11 +170,11 @@ public class NeighborArray {
return "NeighborArray[" + size + "]";
}
- private int ascSortFindRightMostInsertionPoint(float newScore) {
- int insertionPoint = Arrays.binarySearch(score, 0, size, newScore);
+ private int ascSortFindRightMostInsertionPoint(float newScore, int bound) {
+ int insertionPoint = Arrays.binarySearch(score, 0, bound, newScore);
if (insertionPoint >= 0) {
// find the right most position with the same score
- while ((insertionPoint < size - 1) && (score[insertionPoint + 1] == score[insertionPoint])) {
+ while ((insertionPoint < bound - 1) && (score[insertionPoint + 1] == score[insertionPoint])) {
insertionPoint++;
}
insertionPoint++;
@@ -128,9 +184,9 @@ public class NeighborArray {
return insertionPoint;
}
- private int descSortFindRightMostInsertionPoint(float newScore) {
+ private int descSortFindRightMostInsertionPoint(float newScore, int bound) {
int start = 0;
- int end = size - 1;
+ int end = bound - 1;
while (start <= end) {
int mid = (start + end) / 2;
if (score[mid] < newScore) end = mid - 1;
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
index ae39614f160..b7f0ecfd075 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
@@ -171,14 +171,14 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
public long ramBytesUsed() {
long neighborArrayBytes0 =
nsize0 * (Integer.BYTES + Float.BYTES)
- + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2
- + RamUsageEstimator.NUM_BYTES_OBJECT_REF
- + Integer.BYTES * 2;
+ + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
+ + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2
+ + Integer.BYTES * 3;
long neighborArrayBytes =
nsize * (Integer.BYTES + Float.BYTES)
- + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2
- + RamUsageEstimator.NUM_BYTES_OBJECT_REF
- + Integer.BYTES * 2;
+ + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
+ + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2
+ + Integer.BYTES * 3;
long total = 0;
for (int l = 0; l < numLevels; l++) {
if (l == 0) {
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java
index b8ae24f6200..039f69c9dc4 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java
@@ -23,100 +23,160 @@ public class TestNeighborArray extends LuceneTestCase {
public void testScoresDescOrder() {
NeighborArray neighbors = new NeighborArray(10, true);
- neighbors.add(0, 1);
- neighbors.add(1, 0.8f);
+ neighbors.addInOrder(0, 1);
+ neighbors.addInOrder(1, 0.8f);
- AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.add(2, 0.9f));
+ AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.9f));
assertEquals("Nodes are added in the incorrect order!", ex.getMessage());
neighbors.insertSorted(3, 0.9f);
assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors);
- asserNodesEqual(new int[] {0, 3, 1}, neighbors);
+ assertNodesEqual(new int[] {0, 3, 1}, neighbors);
neighbors.insertSorted(4, 1f);
assertScoresEqual(new float[] {1, 1, 0.9f, 0.8f}, neighbors);
- asserNodesEqual(new int[] {0, 4, 3, 1}, neighbors);
+ assertNodesEqual(new int[] {0, 4, 3, 1}, neighbors);
neighbors.insertSorted(5, 1.1f);
assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f}, neighbors);
- asserNodesEqual(new int[] {5, 0, 4, 3, 1}, neighbors);
+ assertNodesEqual(new int[] {5, 0, 4, 3, 1}, neighbors);
neighbors.insertSorted(6, 0.8f);
assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f}, neighbors);
- asserNodesEqual(new int[] {5, 0, 4, 3, 1, 6}, neighbors);
+ assertNodesEqual(new int[] {5, 0, 4, 3, 1, 6}, neighbors);
neighbors.insertSorted(7, 0.8f);
assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
- asserNodesEqual(new int[] {5, 0, 4, 3, 1, 6, 7}, neighbors);
+ assertNodesEqual(new int[] {5, 0, 4, 3, 1, 6, 7}, neighbors);
neighbors.removeIndex(2);
assertScoresEqual(new float[] {1.1f, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
- asserNodesEqual(new int[] {5, 0, 3, 1, 6, 7}, neighbors);
+ assertNodesEqual(new int[] {5, 0, 3, 1, 6, 7}, neighbors);
neighbors.removeIndex(0);
assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
- asserNodesEqual(new int[] {0, 3, 1, 6, 7}, neighbors);
+ assertNodesEqual(new int[] {0, 3, 1, 6, 7}, neighbors);
neighbors.removeIndex(4);
assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f}, neighbors);
- asserNodesEqual(new int[] {0, 3, 1, 6}, neighbors);
+ assertNodesEqual(new int[] {0, 3, 1, 6}, neighbors);
neighbors.removeLast();
assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors);
- asserNodesEqual(new int[] {0, 3, 1}, neighbors);
+ assertNodesEqual(new int[] {0, 3, 1}, neighbors);
neighbors.insertSorted(8, 0.9f);
assertScoresEqual(new float[] {1, 0.9f, 0.9f, 0.8f}, neighbors);
- asserNodesEqual(new int[] {0, 3, 8, 1}, neighbors);
+ assertNodesEqual(new int[] {0, 3, 8, 1}, neighbors);
}
public void testScoresAscOrder() {
NeighborArray neighbors = new NeighborArray(10, false);
- neighbors.add(0, 0.1f);
- neighbors.add(1, 0.3f);
+ neighbors.addInOrder(0, 0.1f);
+ neighbors.addInOrder(1, 0.3f);
- AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.add(2, 0.15f));
+ AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.15f));
assertEquals("Nodes are added in the incorrect order!", ex.getMessage());
neighbors.insertSorted(3, 0.3f);
assertScoresEqual(new float[] {0.1f, 0.3f, 0.3f}, neighbors);
- asserNodesEqual(new int[] {0, 1, 3}, neighbors);
+ assertNodesEqual(new int[] {0, 1, 3}, neighbors);
neighbors.insertSorted(4, 0.2f);
assertScoresEqual(new float[] {0.1f, 0.2f, 0.3f, 0.3f}, neighbors);
- asserNodesEqual(new int[] {0, 4, 1, 3}, neighbors);
+ assertNodesEqual(new int[] {0, 4, 1, 3}, neighbors);
neighbors.insertSorted(5, 0.05f);
assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.3f, 0.3f}, neighbors);
- asserNodesEqual(new int[] {5, 0, 4, 1, 3}, neighbors);
+ assertNodesEqual(new int[] {5, 0, 4, 1, 3}, neighbors);
neighbors.insertSorted(6, 0.2f);
assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors);
- asserNodesEqual(new int[] {5, 0, 4, 6, 1, 3}, neighbors);
+ assertNodesEqual(new int[] {5, 0, 4, 6, 1, 3}, neighbors);
neighbors.insertSorted(7, 0.2f);
assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors);
- asserNodesEqual(new int[] {5, 0, 4, 6, 7, 1, 3}, neighbors);
+ assertNodesEqual(new int[] {5, 0, 4, 6, 7, 1, 3}, neighbors);
neighbors.removeIndex(2);
assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors);
- asserNodesEqual(new int[] {5, 0, 6, 7, 1, 3}, neighbors);
+ assertNodesEqual(new int[] {5, 0, 6, 7, 1, 3}, neighbors);
neighbors.removeIndex(0);
assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors);
- asserNodesEqual(new int[] {0, 6, 7, 1, 3}, neighbors);
+ assertNodesEqual(new int[] {0, 6, 7, 1, 3}, neighbors);
neighbors.removeIndex(4);
assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f, 0.3f}, neighbors);
- asserNodesEqual(new int[] {0, 6, 7, 1}, neighbors);
+ assertNodesEqual(new int[] {0, 6, 7, 1}, neighbors);
neighbors.removeLast();
assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f}, neighbors);
- asserNodesEqual(new int[] {0, 6, 7}, neighbors);
+ assertNodesEqual(new int[] {0, 6, 7}, neighbors);
neighbors.insertSorted(8, 0.01f);
assertScoresEqual(new float[] {0.01f, 0.1f, 0.2f, 0.2f}, neighbors);
- asserNodesEqual(new int[] {8, 0, 6, 7}, neighbors);
+ assertNodesEqual(new int[] {8, 0, 6, 7}, neighbors);
+ }
+
+ public void testSortAsc() {
+ NeighborArray neighbors = new NeighborArray(10, false);
+ neighbors.addOutOfOrder(1, 2);
+ // we disallow calling addInOrder after addOutOfOrder even if they're actually in order
+ expectThrows(AssertionError.class, () -> neighbors.addInOrder(1, 2));
+ neighbors.addOutOfOrder(2, 3);
+ neighbors.addOutOfOrder(5, 6);
+ neighbors.addOutOfOrder(3, 4);
+ neighbors.addOutOfOrder(7, 8);
+ neighbors.addOutOfOrder(6, 7);
+ neighbors.addOutOfOrder(4, 5);
+ int[] unchecked = neighbors.sort();
+ assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked);
+ assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors);
+ assertScoresEqual(new float[] {2, 3, 4, 5, 6, 7, 8}, neighbors);
+
+ NeighborArray neighbors2 = new NeighborArray(10, false);
+ neighbors2.addInOrder(0, 1);
+ neighbors2.addInOrder(1, 2);
+ neighbors2.addInOrder(4, 5);
+ neighbors2.addOutOfOrder(2, 3);
+ neighbors2.addOutOfOrder(6, 7);
+ neighbors2.addOutOfOrder(5, 6);
+ neighbors2.addOutOfOrder(3, 4);
+ unchecked = neighbors2.sort();
+ assertArrayEquals(new int[] {2, 3, 5, 6}, unchecked);
+ assertNodesEqual(new int[] {0, 1, 2, 3, 4, 5, 6}, neighbors2);
+ assertScoresEqual(new float[] {1, 2, 3, 4, 5, 6, 7}, neighbors2);
+ }
+
+ public void testSortDesc() {
+ NeighborArray neighbors = new NeighborArray(10, true);
+ neighbors.addOutOfOrder(1, 7);
+ // we disallow calling addInOrder after addOutOfOrder even if they're actually in order
+ expectThrows(AssertionError.class, () -> neighbors.addInOrder(1, 2));
+ neighbors.addOutOfOrder(2, 6);
+ neighbors.addOutOfOrder(5, 3);
+ neighbors.addOutOfOrder(3, 5);
+ neighbors.addOutOfOrder(7, 1);
+ neighbors.addOutOfOrder(6, 2);
+ neighbors.addOutOfOrder(4, 4);
+ int[] unchecked = neighbors.sort();
+ assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked);
+ assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors);
+ assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors);
+
+ NeighborArray neighbors2 = new NeighborArray(10, true);
+ neighbors2.addInOrder(1, 7);
+ neighbors2.addInOrder(2, 6);
+ neighbors2.addInOrder(5, 3);
+ neighbors2.addOutOfOrder(3, 5);
+ neighbors2.addOutOfOrder(7, 1);
+ neighbors2.addOutOfOrder(6, 2);
+ neighbors2.addOutOfOrder(4, 4);
+ unchecked = neighbors2.sort();
+ assertArrayEquals(new int[] {2, 3, 5, 6}, unchecked);
+ assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors2);
+ assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors2);
}
private void assertScoresEqual(float[] scores, NeighborArray neighbors) {
@@ -125,7 +185,7 @@ public class TestNeighborArray extends LuceneTestCase {
}
}
- private void asserNodesEqual(int[] nodes, NeighborArray neighbors) {
+ private void assertNodesEqual(int[] nodes, NeighborArray neighbors) {
for (int i = 0; i < nodes.length; i++) {
assertEquals(nodes[i], neighbors.node[i]);
}