You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by ma...@apache.org on 2022/05/04 18:15:21 UTC

[lucene] branch main updated: LUCENE-9848 Sort HNSW graph neighbors for construction (#862)

This is an automated email from the ASF dual-hosted git repository.

mayya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/main by this push:
     new dc6a7f94680 LUCENE-9848 Sort HNSW graph neighbors for construction (#862)
dc6a7f94680 is described below

commit dc6a7f94680deba7a5842d35526a966c96faec91
Author: Mayya Sharipova <ma...@elastic.co>
AuthorDate: Wed May 4 14:15:14 2022 -0400

    LUCENE-9848 Sort HNSW graph neighbors for construction (#862)
    
    * LUCENE-9848 Sort HNSW graph neighbors for construction
    
    Sort HNSW graph neighbors when applying diversity criterion
    
    During HNSW graph construction, when a node already has more
    connections than the maximum allowed (maxConn), we need to prune
    its connections using a diversity criterion to limit the number of
    connections to maxConn.
    
    Currently, when we add reverse connections to already existing nodes,
    we don't keep them sorted. Thus later, when we apply the diversity
    criterion, we may not prune the worst (most distant) non-diverse nodes.
    
    This patch makes sure that neighbour connections are always sorted
    from best (closest) to worst (most distant), and that the diversity
    criterion is applied to nodes from worst to best.
    
    This patch does the following:
    - enhance NeighborArray to always keep neighbour nodes sorted according
      to their scores (in desc or asc order). Make NeighborArray aware of
      the order in which the nodes should be sorted.
    - make OnHeapHnswGraph aware of the order of similarity function
    - make HnswGraphBuilder apply the diversity criterion from the worst
      to the best nodes
    - create Lucene90NeighborArray to keep the previous logic of
      NeighborArray for Lucene90Codec
---
 .../lucene90/Lucene90HnswGraphBuilder.java         |  18 +--
 .../lucene90/Lucene90NeighborArray.java}           |  16 ++-
 .../lucene90/Lucene90OnHeapHnswGraph.java          |  11 +-
 .../lucene90/Lucene90HnswVectorsWriter.java        |   3 +-
 .../apache/lucene/util/hnsw/HnswGraphBuilder.java  |  64 ++++------
 .../org/apache/lucene/util/hnsw/NeighborArray.java |  68 ++++++++++-
 .../org/apache/lucene/util/hnsw/NeighborQueue.java |   2 +-
 .../apache/lucene/util/hnsw/OnHeapHnswGraph.java   |   8 +-
 .../apache/lucene/util/hnsw/KnnGraphTester.java    |   3 +-
 .../apache/lucene/util/hnsw/TestNeighborArray.java | 133 +++++++++++++++++++++
 10 files changed, 256 insertions(+), 70 deletions(-)

diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java
index b59ba3b4a7a..0e8afd822b2 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java
@@ -26,7 +26,6 @@ import org.apache.lucene.index.RandomAccessVectorValuesProducer;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.util.InfoStream;
 import org.apache.lucene.util.hnsw.BoundsChecker;
-import org.apache.lucene.util.hnsw.NeighborArray;
 import org.apache.lucene.util.hnsw.NeighborQueue;
 
 /**
@@ -47,7 +46,7 @@ public final class Lucene90HnswGraphBuilder {
 
   private final int maxConn;
   private final int beamWidth;
-  private final NeighborArray scratch;
+  private final Lucene90NeighborArray scratch;
 
   private final VectorSimilarityFunction similarityFunction;
   private final RandomAccessVectorValues vectorValues;
@@ -93,7 +92,7 @@ public final class Lucene90HnswGraphBuilder {
     this.hnsw = new Lucene90OnHeapHnswGraph(maxConn);
     bound = BoundsChecker.create(similarityFunction.reversed);
     random = new SplittableRandom(seed);
-    scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1));
+    scratch = new Lucene90NeighborArray(Math.max(beamWidth, maxConn + 1));
   }
 
   /**
@@ -173,7 +172,7 @@ public final class Lucene90HnswGraphBuilder {
      * is closer to target than it is to any of the already-selected neighbors (ie selected in this method,
      * since the node is new and has no prior neighbors).
      */
-    NeighborArray neighbors = hnsw.getNeighbors(node);
+    Lucene90NeighborArray neighbors = hnsw.getNeighbors(node);
     assert neighbors.size() == 0; // new node
     popToScratch(candidates);
     selectDiverse(neighbors, scratch);
@@ -183,7 +182,7 @@ public final class Lucene90HnswGraphBuilder {
     int size = neighbors.size();
     for (int i = 0; i < size; i++) {
       int nbr = neighbors.node()[i];
-      NeighborArray nbrNbr = hnsw.getNeighbors(nbr);
+      Lucene90NeighborArray nbrNbr = hnsw.getNeighbors(nbr);
       nbrNbr.add(node, neighbors.score()[i]);
       if (nbrNbr.size() > maxConn) {
         diversityUpdate(nbrNbr);
@@ -191,7 +190,8 @@ public final class Lucene90HnswGraphBuilder {
     }
   }
 
-  private void selectDiverse(NeighborArray neighbors, NeighborArray candidates) throws IOException {
+  private void selectDiverse(Lucene90NeighborArray neighbors, Lucene90NeighborArray candidates)
+      throws IOException {
     // Select the best maxConn neighbors of the new node, applying the diversity heuristic
     for (int i = candidates.size() - 1; neighbors.size() < maxConn && i >= 0; i--) {
       // compare each neighbor (in distance order) against the closer neighbors selected so far,
@@ -228,7 +228,7 @@ public final class Lucene90HnswGraphBuilder {
   private boolean diversityCheck(
       float[] candidate,
       float score,
-      NeighborArray neighbors,
+      Lucene90NeighborArray neighbors,
       RandomAccessVectorValues vectorValues)
       throws IOException {
     bound.set(score);
@@ -242,7 +242,7 @@ public final class Lucene90HnswGraphBuilder {
     return true;
   }
 
-  private void diversityUpdate(NeighborArray neighbors) throws IOException {
+  private void diversityUpdate(Lucene90NeighborArray neighbors) throws IOException {
     assert neighbors.size() == maxConn + 1;
     int replacePoint = findNonDiverse(neighbors);
     if (replacePoint == -1) {
@@ -262,7 +262,7 @@ public final class Lucene90HnswGraphBuilder {
   }
 
   // scan neighbors looking for diversity violations
-  private int findNonDiverse(NeighborArray neighbors) throws IOException {
+  private int findNonDiverse(Lucene90NeighborArray neighbors) throws IOException {
     for (int i = neighbors.size() - 1; i >= 0; i--) {
       // check each neighbor against its better-scoring neighbors. If it fails diversity check with
       // them, drop it
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90NeighborArray.java
similarity index 79%
copy from lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
copy to lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90NeighborArray.java
index 40125750309..e2412fcd7da 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90NeighborArray.java
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.lucene.util.hnsw;
+package org.apache.lucene.backward_codecs.lucene90;
 
 import org.apache.lucene.util.ArrayUtil;
 
@@ -25,18 +25,20 @@ import org.apache.lucene.util.ArrayUtil;
  *
  * @lucene.internal
  */
-public class NeighborArray {
+public class Lucene90NeighborArray {
 
   private int size;
 
   float[] score;
   int[] node;
 
-  public NeighborArray(int maxSize) {
+  /** Create a neighbour array with the given initial size */
+  public Lucene90NeighborArray(int maxSize) {
     node = new int[maxSize];
     score = new float[maxSize];
   }
 
+  /** Add a new node with a score */
   public void add(int newNode, float newScore) {
     if (size == node.length - 1) {
       node = ArrayUtil.grow(node, (size + 1) * 3 / 2);
@@ -47,6 +49,7 @@ public class NeighborArray {
     ++size;
   }
 
+  /** Get the size, the number of nodes added so far */
   public int size() {
     return size;
   }
@@ -60,14 +63,21 @@ public class NeighborArray {
     return node;
   }
 
+  /**
+   * Direct access to the internal list of scores
+   *
+   * @lucene.internal
+   */
   public float[] score() {
     return score;
   }
 
+  /** Clear all the nodes in the array */
   public void clear() {
     size = 0;
   }
 
+  /** Remove the last nodes from the array */
   public void removeLast() {
     size--;
   }
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java
index 9de59301abb..6457b8071e9 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java
@@ -29,7 +29,6 @@ import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.SparseFixedBitSet;
 import org.apache.lucene.util.hnsw.BoundsChecker;
 import org.apache.lucene.util.hnsw.HnswGraph;
-import org.apache.lucene.util.hnsw.NeighborArray;
 import org.apache.lucene.util.hnsw.NeighborQueue;
 
 /**
@@ -43,17 +42,17 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
   // Each entry lists the top maxConn neighbors of a node. The nodes correspond to vectors added to
   // HnswBuilder, and the
   // node values are the ordinals of those vectors.
-  private final List<NeighborArray> graph;
+  private final List<Lucene90NeighborArray> graph;
 
   // KnnGraphValues iterator members
   private int upto;
-  private NeighborArray cur;
+  private Lucene90NeighborArray cur;
 
   Lucene90OnHeapHnswGraph(int maxConn) {
     graph = new ArrayList<>();
     // Typically with diversity criteria we see nodes not fully occupied; average fanout seems to be
     // about 1/2 maxConn. There is some indexing time penalty for under-allocating, but saves RAM
-    graph.add(new NeighborArray(Math.max(32, maxConn / 4)));
+    graph.add(new Lucene90NeighborArray(Math.max(32, maxConn / 4)));
     this.maxConn = maxConn;
   }
 
@@ -162,7 +161,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
    *
    * @param node the node whose neighbors are returned
    */
-  public NeighborArray getNeighbors(int node) {
+  public Lucene90NeighborArray getNeighbors(int node) {
     return graph.get(node);
   }
 
@@ -172,7 +171,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
   }
 
   int addNode() {
-    graph.add(new NeighborArray(maxConn + 1));
+    graph.add(new Lucene90NeighborArray(maxConn + 1));
     return graph.size() - 1;
   }
 
diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java
index a71f5efb14f..44e46ab9b16 100644
--- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java
+++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java
@@ -35,7 +35,6 @@ import org.apache.lucene.store.IndexOutput;
 import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.IOUtils;
-import org.apache.lucene.util.hnsw.NeighborArray;
 
 /**
  * Writes vector values and knn graphs to index segments.
@@ -247,7 +246,7 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
       // write graph
       offsets[ord] = graphData.getFilePointer() - graphDataOffset;
 
-      NeighborArray neighbors = graph.getNeighbors(ord);
+      Lucene90NeighborArray neighbors = graph.getNeighbors(ord);
       int size = neighbors.size();
 
       // Destructively modify; it's ok we are discarding it after this
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index dcd1c25a77f..63fe10fba4a 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -95,14 +95,15 @@ public final class HnswGraphBuilder {
     this.ml = 1 / Math.log(1.0 * maxConn);
     this.random = new SplittableRandom(seed);
     int levelOfFirstNode = getRandomGraphLevel(ml, random);
-    this.hnsw = new OnHeapHnswGraph(maxConn, levelOfFirstNode);
+    this.hnsw = new OnHeapHnswGraph(maxConn, levelOfFirstNode, similarityFunction.reversed);
     this.graphSearcher =
         new HnswGraphSearcher(
             similarityFunction,
             new NeighborQueue(beamWidth, similarityFunction.reversed == false),
             new FixedBitSet(vectorValues.size()));
     bound = BoundsChecker.create(similarityFunction.reversed);
-    scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1));
+    // in scratch we store candidates in reverse order: worse candidates are first
+    scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1), similarityFunction.reversed);
   }
 
   /**
@@ -176,11 +177,6 @@ public final class HnswGraphBuilder {
     return now;
   }
 
-  /* TODO: we are not maintaining nodes in strict score order; the forward links
-   * are added in sorted order, but the reverse implicit ones are not. Diversity heuristic should
-   * work better if we keep the neighbor arrays sorted. Possibly we should switch back to a heap?
-   * But first we should just see if sorting makes a significant difference.
-   */
   private void addDiverseNeighbors(int level, int node, NeighborQueue candidates)
       throws IOException {
     /* For each of the beamWidth nearest candidates (going from best to worst), select it only if it
@@ -190,7 +186,7 @@ public final class HnswGraphBuilder {
     NeighborArray neighbors = hnsw.getNeighbors(level, node);
     assert neighbors.size() == 0; // new node
     popToScratch(candidates);
-    selectDiverse(neighbors, scratch);
+    selectAndLinkDiverse(neighbors, scratch);
 
     // Link the selected nodes to the new node, and the new node to the selected nodes (again
     // applying diversity heuristic)
@@ -198,14 +194,16 @@ public final class HnswGraphBuilder {
     for (int i = 0; i < size; i++) {
       int nbr = neighbors.node[i];
       NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr);
-      nbrNbr.add(node, neighbors.score[i]);
+      nbrNbr.insertSorted(node, neighbors.score[i]);
       if (nbrNbr.size() > maxConn) {
-        diversityUpdate(nbrNbr);
+        int indexToRemove = findWorstNonDiverse(nbrNbr);
+        nbrNbr.removeIndex(indexToRemove);
       }
     }
   }
 
-  private void selectDiverse(NeighborArray neighbors, NeighborArray candidates) throws IOException {
+  private void selectAndLinkDiverse(NeighborArray neighbors, NeighborArray candidates)
+      throws IOException {
     // Select the best maxConn neighbors of the new node, applying the diversity heuristic
     for (int i = candidates.size() - 1; neighbors.size() < maxConn && i >= 0; i--) {
       // compare each neighbor (in distance order) against the closer neighbors selected so far,
@@ -256,44 +254,26 @@ public final class HnswGraphBuilder {
     return true;
   }
 
-  private void diversityUpdate(NeighborArray neighbors) throws IOException {
-    assert neighbors.size() == maxConn + 1;
-    int replacePoint = findNonDiverse(neighbors);
-    if (replacePoint == -1) {
-      // none found; check score against worst existing neighbor
-      bound.set(neighbors.score[0]);
-      if (bound.check(neighbors.score[maxConn])) {
-        // drop the new neighbor; it is not competitive and there were no diversity failures
-        neighbors.removeLast();
-        return;
-      } else {
-        replacePoint = 0;
-      }
-    }
-    neighbors.node[replacePoint] = neighbors.node[maxConn];
-    neighbors.score[replacePoint] = neighbors.score[maxConn];
-    neighbors.removeLast();
-  }
-
-  // scan neighbors looking for diversity violations
-  private int findNonDiverse(NeighborArray neighbors) throws IOException {
-    for (int i = neighbors.size() - 1; i >= 0; i--) {
-      // check each neighbor against its better-scoring neighbors. If it fails diversity check with
-      // them, drop it
-      int nbrNode = neighbors.node[i];
+  /**
+   * Find first non-diverse neighbour among the list of neighbors starting from the most distant
+   * neighbours
+   */
+  private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
+    for (int i = neighbors.size() - 1; i > 0; i--) {
+      int cNode = neighbors.node[i];
+      float[] cVector = vectorValues.vectorValue(cNode);
       bound.set(neighbors.score[i]);
-      float[] nbrVector = vectorValues.vectorValue(nbrNode);
-      for (int j = maxConn; j > i; j--) {
+      // check the candidate against its better-scoring neighbors
+      for (int j = i - 1; j >= 0; j--) {
         float diversityCheck =
-            similarityFunction.compare(nbrVector, buildVectors.vectorValue(neighbors.node[j]));
+            similarityFunction.compare(cVector, buildVectors.vectorValue(neighbors.node[j]));
+        // node i is too similar to node j given its score relative to the base node
         if (bound.check(diversityCheck) == false) {
-          // node j is too similar to node i given its score relative to the base node
-          // replace it with the new node, which is at [maxConn]
           return i;
         }
       }
     }
-    return -1;
+    return neighbors.size() - 1;
   }
 
   private static int getRandomGraphLevel(double ml, SplittableRandom random) {
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
index 40125750309..78224ed2358 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
@@ -17,36 +17,67 @@
 
 package org.apache.lucene.util.hnsw;
 
+import java.util.Arrays;
 import org.apache.lucene.util.ArrayUtil;
 
 /**
  * NeighborArray encodes the neighbors of a node and their mutual scores in the HNSW graph as a pair
- * of growable arrays.
+ * of growable arrays. Nodes are arranged in the sorted order of their scores in descending order
+ * (if scoresDescOrder is true), or in the ascending order of their scores (if scoresDescOrder is
+ * false)
  *
  * @lucene.internal
  */
 public class NeighborArray {
-
+  private final boolean scoresDescOrder;
   private int size;
 
   float[] score;
   int[] node;
 
-  public NeighborArray(int maxSize) {
+  public NeighborArray(int maxSize, boolean descOrder) {
     node = new int[maxSize];
     score = new float[maxSize];
+    this.scoresDescOrder = descOrder;
   }
 
+  /**
+   * Add a new node to the NeighborArray. The new node must be worse than all previously stored
+   * nodes.
+   */
   public void add(int newNode, float newScore) {
     if (size == node.length - 1) {
       node = ArrayUtil.grow(node, (size + 1) * 3 / 2);
       score = ArrayUtil.growExact(score, node.length);
     }
+    if (size > 0) {
+      float previousScore = score[size - 1];
+      assert ((scoresDescOrder && (previousScore >= newScore))
+              || (scoresDescOrder == false && (previousScore <= newScore)))
+          : "Nodes are added in the incorrect order!";
+    }
     node[size] = newNode;
     score[size] = newScore;
     ++size;
   }
 
+  /** Add a new node to the NeighborArray into a correct sort position according to its score. */
+  public void insertSorted(int newNode, float newScore) {
+    if (size == node.length - 1) {
+      node = ArrayUtil.grow(node, (size + 1) * 3 / 2);
+      score = ArrayUtil.growExact(score, node.length);
+    }
+    int insertionPoint =
+        scoresDescOrder
+            ? descSortFindRightMostInsertionPoint(newScore)
+            : ascSortFindRightMostInsertionPoint(newScore);
+    System.arraycopy(node, insertionPoint, node, insertionPoint + 1, size - insertionPoint);
+    System.arraycopy(score, insertionPoint, score, insertionPoint + 1, size - insertionPoint);
+    node[insertionPoint] = newNode;
+    score[insertionPoint] = newScore;
+    ++size;
+  }
+
   public int size() {
     return size;
   }
@@ -72,8 +103,39 @@ public class NeighborArray {
     size--;
   }
 
+  public void removeIndex(int idx) {
+    System.arraycopy(node, idx + 1, node, idx, size - idx);
+    System.arraycopy(score, idx + 1, score, idx, size - idx);
+    size--;
+  }
+
   @Override
   public String toString() {
     return "NeighborArray[" + size + "]";
   }
+
+  private int ascSortFindRightMostInsertionPoint(float newScore) {
+    int insertionPoint = Arrays.binarySearch(score, 0, size, newScore);
+    if (insertionPoint >= 0) {
+      // find the right most position with the same score
+      while ((insertionPoint < size - 1) && (score[insertionPoint + 1] == score[insertionPoint])) {
+        insertionPoint++;
+      }
+      insertionPoint++;
+    } else {
+      insertionPoint = -insertionPoint - 1;
+    }
+    return insertionPoint;
+  }
+
+  private int descSortFindRightMostInsertionPoint(float newScore) {
+    int start = 0;
+    int end = size - 1;
+    while (start <= end) {
+      int mid = (start + end) / 2;
+      if (score[mid] < newScore) end = mid - 1;
+      else start = mid + 1;
+    }
+    return start;
+  }
 }
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java
index 2d18cbb99ca..cb58c608f61 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java
@@ -29,7 +29,7 @@ import org.apache.lucene.util.NumericUtils;
  */
 public class NeighborQueue {
 
-  private static enum Order {
+  private enum Order {
     NATURAL {
       @Override
       long apply(long v) {
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
index 09f8afa7aa7..08cecd1f8ac 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
@@ -31,6 +31,7 @@ import org.apache.lucene.util.ArrayUtil;
 public final class OnHeapHnswGraph extends HnswGraph {
 
   private final int maxConn;
+  private final boolean similarityReversed;
   private int numLevels; // the current number of levels in the graph
   private int entryNode; // the current graph entry node on the top level
 
@@ -49,8 +50,9 @@ public final class OnHeapHnswGraph extends HnswGraph {
   private int upto;
   private NeighborArray cur;
 
-  OnHeapHnswGraph(int maxConn, int levelOfFirstNode) {
+  OnHeapHnswGraph(int maxConn, int levelOfFirstNode, boolean similarityReversed) {
     this.maxConn = maxConn;
+    this.similarityReversed = similarityReversed;
     this.numLevels = levelOfFirstNode + 1;
     this.graph = new ArrayList<>(numLevels);
     this.entryNode = 0;
@@ -59,7 +61,7 @@ public final class OnHeapHnswGraph extends HnswGraph {
       // Typically with diversity criteria we see nodes not fully occupied;
       // average fanout seems to be about 1/2 maxConn.
       // There is some indexing time penalty for under-allocating, but saves RAM
-      graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4)));
+      graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4), similarityReversed == false));
     }
 
     this.nodesByLevel = new ArrayList<>(numLevels);
@@ -120,7 +122,7 @@ public final class OnHeapHnswGraph extends HnswGraph {
       }
     }
 
-    graph.get(level).add(new NeighborArray(maxConn + 1));
+    graph.get(level).add(new NeighborArray(maxConn + 1, similarityReversed == false));
   }
 
   @Override
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
index 822ef78197c..57389d098c9 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
@@ -276,7 +276,8 @@ public class KnnGraphTester {
     for (int i = 0; i < hnsw.size(); i++) {
       NeighborArray neighbors = hnsw.getNeighbors(0, i);
       System.out.printf(Locale.ROOT, "%5d", i);
-      NeighborArray sorted = new NeighborArray(neighbors.size());
+      NeighborArray sorted =
+          new NeighborArray(neighbors.size(), similarityFunction.reversed == false);
       for (int j = 0; j < neighbors.size(); j++) {
         int node = neighbors.node[j];
         float score = neighbors.score[j];
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java
new file mode 100644
index 00000000000..b8ae24f6200
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util.hnsw;
+
+import org.apache.lucene.tests.util.LuceneTestCase;
+
+public class TestNeighborArray extends LuceneTestCase {
+
+  public void testScoresDescOrder() {
+    NeighborArray neighbors = new NeighborArray(10, true);
+    neighbors.add(0, 1);
+    neighbors.add(1, 0.8f);
+
+    AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.add(2, 0.9f));
+    assertEquals("Nodes are added in the incorrect order!", ex.getMessage());
+
+    neighbors.insertSorted(3, 0.9f);
+    assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors);
+    asserNodesEqual(new int[] {0, 3, 1}, neighbors);
+
+    neighbors.insertSorted(4, 1f);
+    assertScoresEqual(new float[] {1, 1, 0.9f, 0.8f}, neighbors);
+    asserNodesEqual(new int[] {0, 4, 3, 1}, neighbors);
+
+    neighbors.insertSorted(5, 1.1f);
+    assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f}, neighbors);
+    asserNodesEqual(new int[] {5, 0, 4, 3, 1}, neighbors);
+
+    neighbors.insertSorted(6, 0.8f);
+    assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f}, neighbors);
+    asserNodesEqual(new int[] {5, 0, 4, 3, 1, 6}, neighbors);
+
+    neighbors.insertSorted(7, 0.8f);
+    assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
+    asserNodesEqual(new int[] {5, 0, 4, 3, 1, 6, 7}, neighbors);
+
+    neighbors.removeIndex(2);
+    assertScoresEqual(new float[] {1.1f, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
+    asserNodesEqual(new int[] {5, 0, 3, 1, 6, 7}, neighbors);
+
+    neighbors.removeIndex(0);
+    assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
+    asserNodesEqual(new int[] {0, 3, 1, 6, 7}, neighbors);
+
+    neighbors.removeIndex(4);
+    assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f}, neighbors);
+    asserNodesEqual(new int[] {0, 3, 1, 6}, neighbors);
+
+    neighbors.removeLast();
+    assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors);
+    asserNodesEqual(new int[] {0, 3, 1}, neighbors);
+
+    neighbors.insertSorted(8, 0.9f);
+    assertScoresEqual(new float[] {1, 0.9f, 0.9f, 0.8f}, neighbors);
+    asserNodesEqual(new int[] {0, 3, 8, 1}, neighbors);
+  }
+
+  public void testScoresAscOrder() {
+    NeighborArray neighbors = new NeighborArray(10, false);
+    neighbors.add(0, 0.1f);
+    neighbors.add(1, 0.3f);
+
+    AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.add(2, 0.15f));
+    assertEquals("Nodes are added in the incorrect order!", ex.getMessage());
+
+    neighbors.insertSorted(3, 0.3f);
+    assertScoresEqual(new float[] {0.1f, 0.3f, 0.3f}, neighbors);
+    asserNodesEqual(new int[] {0, 1, 3}, neighbors);
+
+    neighbors.insertSorted(4, 0.2f);
+    assertScoresEqual(new float[] {0.1f, 0.2f, 0.3f, 0.3f}, neighbors);
+    asserNodesEqual(new int[] {0, 4, 1, 3}, neighbors);
+
+    neighbors.insertSorted(5, 0.05f);
+    assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.3f, 0.3f}, neighbors);
+    asserNodesEqual(new int[] {5, 0, 4, 1, 3}, neighbors);
+
+    neighbors.insertSorted(6, 0.2f);
+    assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors);
+    asserNodesEqual(new int[] {5, 0, 4, 6, 1, 3}, neighbors);
+
+    neighbors.insertSorted(7, 0.2f);
+    assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors);
+    asserNodesEqual(new int[] {5, 0, 4, 6, 7, 1, 3}, neighbors);
+
+    neighbors.removeIndex(2);
+    assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors);
+    asserNodesEqual(new int[] {5, 0, 6, 7, 1, 3}, neighbors);
+
+    neighbors.removeIndex(0);
+    assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors);
+    asserNodesEqual(new int[] {0, 6, 7, 1, 3}, neighbors);
+
+    neighbors.removeIndex(4);
+    assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f, 0.3f}, neighbors);
+    asserNodesEqual(new int[] {0, 6, 7, 1}, neighbors);
+
+    neighbors.removeLast();
+    assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f}, neighbors);
+    asserNodesEqual(new int[] {0, 6, 7}, neighbors);
+
+    neighbors.insertSorted(8, 0.01f);
+    assertScoresEqual(new float[] {0.01f, 0.1f, 0.2f, 0.2f}, neighbors);
+    asserNodesEqual(new int[] {8, 0, 6, 7}, neighbors);
+  }
+
+  private void assertScoresEqual(float[] scores, NeighborArray neighbors) {
+    for (int i = 0; i < scores.length; i++) {
+      assertEquals(scores[i], neighbors.score[i], 0.01f);
+    }
+  }
+
+  private void asserNodesEqual(int[] nodes, NeighborArray neighbors) {
+    for (int i = 0; i < nodes.length; i++) {
+      assertEquals(nodes[i], neighbors.node[i]);
+    }
+  }
+}