You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by so...@apache.org on 2022/09/19 15:49:04 UTC
[lucene] branch main updated: Diversity check bugfix (#11781)
This is an automated email from the ASF dual-hosted git repository.
sokolov pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/lucene.git
The following commit(s) were added to refs/heads/main by this push:
new 07af358f90d Diversity check bugfix (#11781)
07af358f90d is described below
commit 07af358f90d296cc2d224692b4b32082ee8e577a
Author: Michael Sokolov <so...@falutin.net>
AuthorDate: Mon Sep 19 11:48:59 2022 -0400
Diversity check bugfix (#11781)
* Fixes bug in HNSW diversity checks introduced in LUCENE-10577
---
.../apache/lucene/util/hnsw/HnswGraphBuilder.java | 42 ++++++-------
.../org/apache/lucene/util/hnsw/TestHnswGraph.java | 73 ++++++++++++++++++++++
2 files changed, 94 insertions(+), 21 deletions(-)
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index d5ed059b0c2..cc8e330ffc8 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -109,7 +109,7 @@ public final class HnswGraphBuilder<T> {
this.M = M;
this.beamWidth = beamWidth;
// normalization factor for level generation; currently not configurable
- this.ml = 1 / Math.log(1.0 * M);
+ this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M);
this.random = new SplittableRandom(seed);
int levelOfFirstNode = getRandomGraphLevel(ml, random);
this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode);
@@ -316,49 +316,49 @@ public final class HnswGraphBuilder<T> {
*/
private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
for (int i = neighbors.size() - 1; i > 0; i--) {
- if (isWorstNonDiverse(i, neighbors, neighbors.score[i])) {
+ if (isWorstNonDiverse(i, neighbors)) {
return i;
}
}
return neighbors.size() - 1;
}
- private boolean isWorstNonDiverse(
- int candidate, NeighborArray neighbors, float minAcceptedSimilarity) throws IOException {
+ private boolean isWorstNonDiverse(int candidateIndex, NeighborArray neighbors)
+ throws IOException {
+ int candidateNode = neighbors.node[candidateIndex];
return switch (vectorEncoding) {
- case BYTE -> isWorstNonDiverse(
- candidate, vectors.binaryValue(candidate), neighbors, minAcceptedSimilarity);
+ case BYTE -> isWorstNonDiverse(candidateIndex, vectors.binaryValue(candidateNode), neighbors);
case FLOAT32 -> isWorstNonDiverse(
- candidate, vectors.vectorValue(candidate), neighbors, minAcceptedSimilarity);
+ candidateIndex, vectors.vectorValue(candidateNode), neighbors);
};
}
private boolean isWorstNonDiverse(
- int candidateIndex, float[] candidate, NeighborArray neighbors, float minAcceptedSimilarity)
- throws IOException {
- for (int i = candidateIndex - 1; i > -0; i--) {
+ int candidateIndex, float[] candidateVector, NeighborArray neighbors) throws IOException {
+ float minAcceptedSimilarity = neighbors.score[candidateIndex];
+ for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
- similarityFunction.compare(candidate, vectorsCopy.vectorValue(neighbors.node[i]));
- // node i is too similar to node j given its score relative to the base node
+ similarityFunction.compare(candidateVector, vectorsCopy.vectorValue(neighbors.node[i]));
+ // candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
- return false;
+ return true;
}
}
- return true;
+ return false;
}
private boolean isWorstNonDiverse(
- int candidateIndex, BytesRef candidate, NeighborArray neighbors, float minAcceptedSimilarity)
- throws IOException {
- for (int i = candidateIndex - 1; i > -0; i--) {
+ int candidateIndex, BytesRef candidateVector, NeighborArray neighbors) throws IOException {
+ float minAcceptedSimilarity = neighbors.score[candidateIndex];
+ for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
- similarityFunction.compare(candidate, vectorsCopy.binaryValue(neighbors.node[i]));
- // node i is too similar to node j given its score relative to the base node
+ similarityFunction.compare(candidateVector, vectorsCopy.binaryValue(neighbors.node[i]));
+ // candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
- return false;
+ return true;
}
}
- return true;
+ return false;
}
private static int getRandomGraphLevel(double ml, SplittableRandom random) {
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
index 7852e706157..16d1996820a 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
@@ -504,6 +504,7 @@ public class TestHnswGraph extends LuceneTestCase {
unitVector2d(0.9),
unitVector2d(0.8),
unitVector2d(0.77),
+ unitVector2d(0.6)
};
if (vectorEncoding == VectorEncoding.BYTE) {
for (float[] v : values) {
@@ -555,6 +556,78 @@ public class TestHnswGraph extends LuceneTestCase {
assertLevel0Neighbors(builder.hnsw, 5, 1, 4);
}
+ public void testDiversityFallback() throws IOException {
+ vectorEncoding = randomVectorEncoding();
+ similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
+ // Some test cases can't be exercised in two dimensions;
+ // in particular if a new neighbor displaces an existing neighbor
+ // by being closer to the target, yet none of the existing neighbors is closer to the new vector
+ // than to the target -- ie they all remain diverse, so we simply drop the farthest one.
+ float[][] values = {
+ {0, 0, 0},
+ {0, 10, 0},
+ {0, 0, 20},
+ {10, 0, 0},
+ {0, 4, 0}
+ };
+ MockVectorValues vectors = new MockVectorValues(values);
+ // First add nodes until everybody gets a full neighbor list
+ HnswGraphBuilder<?> builder =
+ HnswGraphBuilder.create(
+ vectors, vectorEncoding, similarityFunction, 1, 10, random().nextInt());
+ // node 0 is added by the builder constructor
+ // builder.addGraphNode(vectors.vectorValue(0));
+ RandomAccessVectorValues vectorsCopy = vectors.copy();
+ builder.addGraphNode(1, vectorsCopy);
+ builder.addGraphNode(2, vectorsCopy);
+ assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
+ // 2 is closer to 0 than to 1, so it is excluded as non-diverse
+ assertLevel0Neighbors(builder.hnsw, 1, 0);
+ // 1 is closer to 0 than to 2, so it is excluded as non-diverse
+ assertLevel0Neighbors(builder.hnsw, 2, 0);
+
+ builder.addGraphNode(3, vectorsCopy);
+ // this is one case we are testing; 2 has been displaced by 3
+ assertLevel0Neighbors(builder.hnsw, 0, 1, 3);
+ assertLevel0Neighbors(builder.hnsw, 1, 0);
+ assertLevel0Neighbors(builder.hnsw, 2, 0);
+ assertLevel0Neighbors(builder.hnsw, 3, 0);
+ }
+
+ public void testDiversity3d() throws IOException {
+ vectorEncoding = randomVectorEncoding();
+ similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
+ // test the case when a neighbor *becomes* non-diverse when a newer better neighbor arrives
+ float[][] values = {
+ {0, 0, 0},
+ {0, 10, 0},
+ {0, 0, 20},
+ {0, 9, 0}
+ };
+ MockVectorValues vectors = new MockVectorValues(values);
+ // First add nodes until everybody gets a full neighbor list
+ HnswGraphBuilder<?> builder =
+ HnswGraphBuilder.create(
+ vectors, vectorEncoding, similarityFunction, 1, 10, random().nextInt());
+ // node 0 is added by the builder constructor
+ // builder.addGraphNode(vectors.vectorValue(0));
+ RandomAccessVectorValues vectorsCopy = vectors.copy();
+ builder.addGraphNode(1, vectorsCopy);
+ builder.addGraphNode(2, vectorsCopy);
+ assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
+ // 2 is closer to 0 than to 1, so it is excluded as non-diverse
+ assertLevel0Neighbors(builder.hnsw, 1, 0);
+ // 1 is closer to 0 than to 2, so it is excluded as non-diverse
+ assertLevel0Neighbors(builder.hnsw, 2, 0);
+
+ builder.addGraphNode(3, vectorsCopy);
+ // this is one case we are testing; 1 has been displaced by 3
+ assertLevel0Neighbors(builder.hnsw, 0, 2, 3);
+ assertLevel0Neighbors(builder.hnsw, 1, 0, 3);
+ assertLevel0Neighbors(builder.hnsw, 2, 0);
+ assertLevel0Neighbors(builder.hnsw, 3, 0, 1);
+ }
+
private void assertLevel0Neighbors(OnHeapHnswGraph graph, int node, int... expected) {
Arrays.sort(expected);
NeighborArray nn = graph.getNeighbors(0, node);