You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by so...@apache.org on 2022/09/19 15:55:44 UTC

[lucene] branch branch_9x updated: Diversity check bugfix (#11781)

This is an automated email from the ASF dual-hosted git repository.

sokolov pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/branch_9x by this push:
     new 3393b63c942 Diversity check bugfix (#11781)
3393b63c942 is described below

commit 3393b63c942a5a50edcd5a24fd3edad2766ba673
Author: Michael Sokolov <so...@falutin.net>
AuthorDate: Mon Sep 19 11:48:59 2022 -0400

    Diversity check bugfix (#11781)
    
    * Fixes bug in HNSW diversity checks introduced in LUCENE-10577
---
 .../apache/lucene/util/hnsw/HnswGraphBuilder.java  | 43 +++++++------
 .../org/apache/lucene/util/hnsw/TestHnswGraph.java | 73 ++++++++++++++++++++++
 2 files changed, 94 insertions(+), 22 deletions(-)

diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index 282db736144..bd14ac6016c 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -109,7 +109,7 @@ public final class HnswGraphBuilder<T> {
     this.M = M;
     this.beamWidth = beamWidth;
     // normalization factor for level generation; currently not configurable
-    this.ml = 1 / Math.log(1.0 * M);
+    this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M);
     this.random = new SplittableRandom(seed);
     int levelOfFirstNode = getRandomGraphLevel(ml, random);
     this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode);
@@ -322,52 +322,51 @@ public final class HnswGraphBuilder<T> {
    */
   private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
     for (int i = neighbors.size() - 1; i > 0; i--) {
-      if (isWorstNonDiverse(i, neighbors, neighbors.score[i])) {
+      if (isWorstNonDiverse(i, neighbors)) {
         return i;
       }
     }
     return neighbors.size() - 1;
   }
 
-  private boolean isWorstNonDiverse(
-      int candidate, NeighborArray neighbors, float minAcceptedSimilarity) throws IOException {
+  private boolean isWorstNonDiverse(int candidateIndex, NeighborArray neighbors)
+      throws IOException {
+    int candidateNode = neighbors.node[candidateIndex];
     switch (vectorEncoding) {
       case BYTE:
-        return isWorstNonDiverse(
-            candidate, vectors.binaryValue(candidate), neighbors, minAcceptedSimilarity);
+        return isWorstNonDiverse(candidateIndex, vectors.binaryValue(candidateNode), neighbors);
       default:
       case FLOAT32:
-        return isWorstNonDiverse(
-            candidate, vectors.vectorValue(candidate), neighbors, minAcceptedSimilarity);
+        return isWorstNonDiverse(candidateIndex, vectors.vectorValue(candidateNode), neighbors);
     }
   }
 
   private boolean isWorstNonDiverse(
-      int candidateIndex, float[] candidate, NeighborArray neighbors, float minAcceptedSimilarity)
-      throws IOException {
-    for (int i = candidateIndex - 1; i > -0; i--) {
+      int candidateIndex, float[] candidateVector, NeighborArray neighbors) throws IOException {
+    float minAcceptedSimilarity = neighbors.score[candidateIndex];
+    for (int i = candidateIndex - 1; i >= 0; i--) {
       float neighborSimilarity =
-          similarityFunction.compare(candidate, vectorsCopy.vectorValue(neighbors.node[i]));
-      // node i is too similar to node j given its score relative to the base node
+          similarityFunction.compare(candidateVector, vectorsCopy.vectorValue(neighbors.node[i]));
+      // candidate node is too similar to node i given its score relative to the base node
       if (neighborSimilarity >= minAcceptedSimilarity) {
-        return false;
+        return true;
       }
     }
-    return true;
+    return false;
   }
 
   private boolean isWorstNonDiverse(
-      int candidateIndex, BytesRef candidate, NeighborArray neighbors, float minAcceptedSimilarity)
-      throws IOException {
-    for (int i = candidateIndex - 1; i > -0; i--) {
+      int candidateIndex, BytesRef candidateVector, NeighborArray neighbors) throws IOException {
+    float minAcceptedSimilarity = neighbors.score[candidateIndex];
+    for (int i = candidateIndex - 1; i >= 0; i--) {
       float neighborSimilarity =
-          similarityFunction.compare(candidate, vectorsCopy.binaryValue(neighbors.node[i]));
-      // node i is too similar to node j given its score relative to the base node
+          similarityFunction.compare(candidateVector, vectorsCopy.binaryValue(neighbors.node[i]));
+      // candidate node is too similar to node i given its score relative to the base node
       if (neighborSimilarity >= minAcceptedSimilarity) {
-        return false;
+        return true;
       }
     }
-    return true;
+    return false;
   }
 
   private static int getRandomGraphLevel(double ml, SplittableRandom random) {
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
index 1c88a728c96..7cb192b9367 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
@@ -504,6 +504,7 @@ public class TestHnswGraph extends LuceneTestCase {
       unitVector2d(0.9),
       unitVector2d(0.8),
       unitVector2d(0.77),
+      unitVector2d(0.6)
     };
     if (vectorEncoding == VectorEncoding.BYTE) {
       for (float[] v : values) {
@@ -555,6 +556,78 @@ public class TestHnswGraph extends LuceneTestCase {
     assertLevel0Neighbors(builder.hnsw, 5, 1, 4);
   }
 
+  public void testDiversityFallback() throws IOException {
+    vectorEncoding = randomVectorEncoding();
+    similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
+    // Some test cases can't be exercised in two dimensions;
+    // in particular if a new neighbor displaces an existing neighbor
+    // by being closer to the target, yet none of the existing neighbors is closer to the new vector
+    // than to the target -- i.e., they all remain diverse, so we simply drop the farthest one.
+    float[][] values = {
+      {0, 0, 0},
+      {0, 10, 0},
+      {0, 0, 20},
+      {10, 0, 0},
+      {0, 4, 0}
+    };
+    MockVectorValues vectors = new MockVectorValues(values);
+    // First add nodes until everybody gets a full neighbor list
+    HnswGraphBuilder<?> builder =
+        HnswGraphBuilder.create(
+            vectors, vectorEncoding, similarityFunction, 1, 10, random().nextInt());
+    // node 0 is added by the builder constructor
+    // builder.addGraphNode(vectors.vectorValue(0));
+    RandomAccessVectorValues vectorsCopy = vectors.copy();
+    builder.addGraphNode(1, vectorsCopy);
+    builder.addGraphNode(2, vectorsCopy);
+    assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
+    // 2 is closer to 0 than to 1, so it is excluded from 1's neighbors as non-diverse
+    assertLevel0Neighbors(builder.hnsw, 1, 0);
+    // 1 is closer to 0 than to 2, so it is excluded from 2's neighbors as non-diverse
+    assertLevel0Neighbors(builder.hnsw, 2, 0);
+
+    builder.addGraphNode(3, vectorsCopy);
+    // this is one case we are testing; 2 has been displaced by 3
+    assertLevel0Neighbors(builder.hnsw, 0, 1, 3);
+    assertLevel0Neighbors(builder.hnsw, 1, 0);
+    assertLevel0Neighbors(builder.hnsw, 2, 0);
+    assertLevel0Neighbors(builder.hnsw, 3, 0);
+  }
+
+  public void testDiversity3d() throws IOException {
+    vectorEncoding = randomVectorEncoding();
+    similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
+    // test the case when a neighbor *becomes* non-diverse when a newer better neighbor arrives
+    float[][] values = {
+      {0, 0, 0},
+      {0, 10, 0},
+      {0, 0, 20},
+      {0, 9, 0}
+    };
+    MockVectorValues vectors = new MockVectorValues(values);
+    // First add nodes until everybody gets a full neighbor list
+    HnswGraphBuilder<?> builder =
+        HnswGraphBuilder.create(
+            vectors, vectorEncoding, similarityFunction, 1, 10, random().nextInt());
+    // node 0 is added by the builder constructor
+    // builder.addGraphNode(vectors.vectorValue(0));
+    RandomAccessVectorValues vectorsCopy = vectors.copy();
+    builder.addGraphNode(1, vectorsCopy);
+    builder.addGraphNode(2, vectorsCopy);
+    assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
+    // 2 is closer to 0 than to 1, so it is excluded from 1's neighbors as non-diverse
+    assertLevel0Neighbors(builder.hnsw, 1, 0);
+    // 1 is closer to 0 than to 2, so it is excluded from 2's neighbors as non-diverse
+    assertLevel0Neighbors(builder.hnsw, 2, 0);
+
+    builder.addGraphNode(3, vectorsCopy);
+    // this is one case we are testing; 1 has been displaced by 3
+    assertLevel0Neighbors(builder.hnsw, 0, 2, 3);
+    assertLevel0Neighbors(builder.hnsw, 1, 0, 3);
+    assertLevel0Neighbors(builder.hnsw, 2, 0);
+    assertLevel0Neighbors(builder.hnsw, 3, 0, 1);
+  }
+
   private void assertLevel0Neighbors(OnHeapHnswGraph graph, int node, int... expected) {
     Arrays.sort(expected);
     NeighborArray nn = graph.getNeighbors(0, node);