You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by pa...@apache.org on 2023/05/16 06:20:37 UTC

[lucene] branch main updated: Optimize HNSW diversity calculation (#12235)

This is an automated email from the ASF dual-hosted git repository.

patrickz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/main by this push:
     new 8af305892d7 Optimize HNSW diversity calculation (#12235)
8af305892d7 is described below

commit 8af305892d726c180f03316c73aebf8183c2e481
Author: Patrick Zhai <zh...@users.noreply.github.com>
AuthorDate: Mon May 15 23:20:31 2023 -0700

    Optimize HNSW diversity calculation (#12235)
---
 lucene/CHANGES.txt                                 |   2 +
 .../synonym/word2vec/Word2VecSynonymProvider.java  |   2 +-
 .../lucene90/Lucene90HnswGraphBuilder.java         |   8 +-
 .../lucene91/Lucene91HnswGraphBuilder.java         |   8 +-
 .../apache/lucene/util/hnsw/HnswGraphBuilder.java  | 124 ++++++++++++++++-----
 .../org/apache/lucene/util/hnsw/NeighborArray.java |  88 ++++++++++++---
 .../apache/lucene/util/hnsw/OnHeapHnswGraph.java   |  12 +-
 .../apache/lucene/util/hnsw/TestNeighborArray.java | 114 ++++++++++++++-----
 8 files changed, 272 insertions(+), 86 deletions(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 7876d4bfbf5..308e28f27cd 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -128,6 +128,8 @@ Optimizations
 
 * GITHUB#12286 Toposort use iterator to avoid stackoverflow. (Tang Donghai)
 
+* GITHUB#12235: Optimize HNSW diversity calculation. (Patrick Zhai)
+
 Bug Fixes
 ---------------------
 (No changes)
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java
index 1d1f06fc991..3784d9372fe 100644
--- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java
@@ -85,7 +85,7 @@ public class Word2VecSynonymProvider {
               SIMILARITY_FUNCTION,
               hnswGraph,
               null,
-              word2VecModel.size());
+              Integer.MAX_VALUE);
 
       int size = synonyms.size();
       for (int i = 0; i < size; i++) {
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java
index e7f16b4f3fc..4b1f7068a5f 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java
@@ -183,10 +183,10 @@ public final class Lucene90HnswGraphBuilder {
     int size = neighbors.size();
     for (int i = 0; i < size; i++) {
       int nbr = neighbors.node()[i];
-      Lucene90NeighborArray nbrNbr = hnsw.getNeighbors(nbr);
-      nbrNbr.add(node, neighbors.score()[i]);
-      if (nbrNbr.size() > maxConn) {
-        diversityUpdate(nbrNbr);
+      Lucene90NeighborArray nbrsOfNbr = hnsw.getNeighbors(nbr);
+      nbrsOfNbr.add(node, neighbors.score()[i]);
+      if (nbrsOfNbr.size() > maxConn) {
+        diversityUpdate(nbrsOfNbr);
       }
     }
   }
diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java
index b0e9d160457..c82920181cc 100644
--- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java
+++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java
@@ -204,10 +204,10 @@ public final class Lucene91HnswGraphBuilder {
     int size = neighbors.size();
     for (int i = 0; i < size; i++) {
       int nbr = neighbors.node[i];
-      Lucene91NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr);
-      nbrNbr.add(node, neighbors.score[i]);
-      if (nbrNbr.size() > maxConn) {
-        diversityUpdate(nbrNbr);
+      Lucene91NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr);
+      nbrsOfNbr.add(node, neighbors.score[i]);
+      if (nbrsOfNbr.size() > maxConn) {
+        diversityUpdate(nbrsOfNbr);
       }
     }
   }
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index edefb696b8c..e2a57a303c6 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -218,7 +218,9 @@ public final class HnswGraphBuilder<T> {
                 case BYTE -> this.similarityFunction.compare(
                     binaryValue, (byte[]) vectorsCopy.vectorValue(newNeighbor));
               };
-          newNeighbors.insertSorted(newNeighbor, score);
+          // we are not sure whether the previous graph contains
+          // unchecked nodes, so we have to assume they're all unchecked
+          newNeighbors.addOutOfOrder(newNeighbor, score);
         }
       }
     }
@@ -314,11 +316,11 @@ public final class HnswGraphBuilder<T> {
     int size = neighbors.size();
     for (int i = 0; i < size; i++) {
       int nbr = neighbors.node[i];
-      NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr);
-      nbrNbr.insertSorted(node, neighbors.score[i]);
-      if (nbrNbr.size() > maxConnOnLevel) {
-        int indexToRemove = findWorstNonDiverse(nbrNbr);
-        nbrNbr.removeIndex(indexToRemove);
+      NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr);
+      nbrsOfNbr.addOutOfOrder(node, neighbors.score[i]);
+      if (nbrsOfNbr.size() > maxConnOnLevel) {
+        int indexToRemove = findWorstNonDiverse(nbrsOfNbr);
+        nbrsOfNbr.removeIndex(indexToRemove);
       }
     }
   }
@@ -333,7 +335,7 @@ public final class HnswGraphBuilder<T> {
       float cScore = candidates.score[i];
       assert cNode < hnsw.size();
       if (diversityCheck(cNode, cScore, neighbors)) {
-        neighbors.add(cNode, cScore);
+        neighbors.addInOrder(cNode, cScore);
       }
     }
   }
@@ -345,7 +347,7 @@ public final class HnswGraphBuilder<T> {
     // sorted from worst to best
     for (int i = 0; i < candidateCount; i++) {
       float maxSimilarity = candidates.topScore();
-      scratch.add(candidates.pop(), maxSimilarity);
+      scratch.addInOrder(candidates.pop(), maxSimilarity);
     }
   }
 
@@ -400,50 +402,116 @@ public final class HnswGraphBuilder<T> {
    * neighbours
    */
   private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
+    int[] uncheckedIndexes = neighbors.sort();
+    if (uncheckedIndexes == null) {
+      // all nodes are checked, we will directly return the most distant one
+      return neighbors.size() - 1;
+    }
+    int uncheckedCursor = uncheckedIndexes.length - 1;
     for (int i = neighbors.size() - 1; i > 0; i--) {
-      if (isWorstNonDiverse(i, neighbors)) {
+      if (uncheckedCursor < 0) {
+        // no unchecked node left
+        break;
+      }
+      if (isWorstNonDiverse(i, neighbors, uncheckedIndexes, uncheckedCursor)) {
         return i;
       }
+      if (i == uncheckedIndexes[uncheckedCursor]) {
+        uncheckedCursor--;
+      }
     }
     return neighbors.size() - 1;
   }
 
-  private boolean isWorstNonDiverse(int candidateIndex, NeighborArray neighbors)
+  private boolean isWorstNonDiverse(
+      int candidateIndex, NeighborArray neighbors, int[] uncheckedIndexes, int uncheckedCursor)
       throws IOException {
     int candidateNode = neighbors.node[candidateIndex];
     return switch (vectorEncoding) {
       case BYTE -> isWorstNonDiverse(
-          candidateIndex, (byte[]) vectors.vectorValue(candidateNode), neighbors);
+          candidateIndex,
+          (byte[]) vectors.vectorValue(candidateNode),
+          neighbors,
+          uncheckedIndexes,
+          uncheckedCursor);
       case FLOAT32 -> isWorstNonDiverse(
-          candidateIndex, (float[]) vectors.vectorValue(candidateNode), neighbors);
+          candidateIndex,
+          (float[]) vectors.vectorValue(candidateNode),
+          neighbors,
+          uncheckedIndexes,
+          uncheckedCursor);
     };
   }
 
   private boolean isWorstNonDiverse(
-      int candidateIndex, float[] candidateVector, NeighborArray neighbors) throws IOException {
+      int candidateIndex,
+      float[] candidateVector,
+      NeighborArray neighbors,
+      int[] uncheckedIndexes,
+      int uncheckedCursor)
+      throws IOException {
     float minAcceptedSimilarity = neighbors.score[candidateIndex];
-    for (int i = candidateIndex - 1; i >= 0; i--) {
-      float neighborSimilarity =
-          similarityFunction.compare(
-              candidateVector, (float[]) vectorsCopy.vectorValue(neighbors.node[i]));
-      // candidate node is too similar to node i given its score relative to the base node
-      if (neighborSimilarity >= minAcceptedSimilarity) {
-        return true;
+    if (candidateIndex == uncheckedIndexes[uncheckedCursor]) {
+      // the candidate itself is unchecked
+      for (int i = candidateIndex - 1; i >= 0; i--) {
+        float neighborSimilarity =
+            similarityFunction.compare(
+                candidateVector, (float[]) vectorsCopy.vectorValue(neighbors.node[i]));
+        // candidate node is too similar to node i given its score relative to the base node
+        if (neighborSimilarity >= minAcceptedSimilarity) {
+          return true;
+        }
+      }
+    } else {
+      // else we just need to make sure candidate does not violate diversity with the (newly
+      // inserted) unchecked nodes
+      assert candidateIndex > uncheckedIndexes[uncheckedCursor];
+      for (int i = uncheckedCursor; i >= 0; i--) {
+        float neighborSimilarity =
+            similarityFunction.compare(
+                candidateVector,
+                (float[]) vectorsCopy.vectorValue(neighbors.node[uncheckedIndexes[i]]));
+        // candidate node is too similar to node i given its score relative to the base node
+        if (neighborSimilarity >= minAcceptedSimilarity) {
+          return true;
+        }
       }
     }
     return false;
   }
 
   private boolean isWorstNonDiverse(
-      int candidateIndex, byte[] candidateVector, NeighborArray neighbors) throws IOException {
+      int candidateIndex,
+      byte[] candidateVector,
+      NeighborArray neighbors,
+      int[] uncheckedIndexes,
+      int uncheckedCursor)
+      throws IOException {
     float minAcceptedSimilarity = neighbors.score[candidateIndex];
-    for (int i = candidateIndex - 1; i >= 0; i--) {
-      float neighborSimilarity =
-          similarityFunction.compare(
-              candidateVector, (byte[]) vectorsCopy.vectorValue(neighbors.node[i]));
-      // candidate node is too similar to node i given its score relative to the base node
-      if (neighborSimilarity >= minAcceptedSimilarity) {
-        return true;
+    if (candidateIndex == uncheckedIndexes[uncheckedCursor]) {
+      // the candidate itself is unchecked
+      for (int i = candidateIndex - 1; i >= 0; i--) {
+        float neighborSimilarity =
+            similarityFunction.compare(
+                candidateVector, (byte[]) vectorsCopy.vectorValue(neighbors.node[i]));
+        // candidate node is too similar to node i given its score relative to the base node
+        if (neighborSimilarity >= minAcceptedSimilarity) {
+          return true;
+        }
+      }
+    } else {
+      // else we just need to make sure candidate does not violate diversity with the (newly
+      // inserted) unchecked nodes
+      assert candidateIndex > uncheckedIndexes[uncheckedCursor];
+      for (int i = uncheckedCursor; i >= 0; i--) {
+        float neighborSimilarity =
+            similarityFunction.compare(
+                candidateVector,
+                (byte[]) vectorsCopy.vectorValue(neighbors.node[uncheckedIndexes[i]]));
+        // candidate node is too similar to node i given its score relative to the base node
+        if (neighborSimilarity >= minAcceptedSimilarity) {
+          return true;
+        }
       }
     }
     return false;
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
index ec1b5ec3e89..a23b9b5254e 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
@@ -34,6 +34,7 @@ public class NeighborArray {
 
   float[] score;
   int[] node;
+  private int sortedNodeSize;
 
   public NeighborArray(int maxSize, boolean descOrder) {
     node = new int[maxSize];
@@ -43,9 +44,10 @@ public class NeighborArray {
 
   /**
    * Add a new node to the NeighborArray. The new node must be worse than all previously stored
-   * nodes.
+   * nodes. This cannot be called after {@link #addOutOfOrder(int, float)}
    */
-  public void add(int newNode, float newScore) {
+  public void addInOrder(int newNode, float newScore) {
+    assert size == sortedNodeSize : "cannot call addInOrder after addOutOfOrder";
     if (size == node.length) {
       node = ArrayUtil.grow(node);
       score = ArrayUtil.growExact(score, node.length);
@@ -59,23 +61,72 @@ public class NeighborArray {
     node[size] = newNode;
     score[size] = newScore;
     ++size;
+    ++sortedNodeSize;
   }
 
-  /** Add a new node to the NeighborArray into a correct sort position according to its score. */
-  public void insertSorted(int newNode, float newScore) {
+  /** Add node and score but do not insert as sorted */
+  public void addOutOfOrder(int newNode, float newScore) {
     if (size == node.length) {
       node = ArrayUtil.grow(node);
       score = ArrayUtil.growExact(score, node.length);
     }
+    node[size] = newNode;
+    score[size] = newScore;
+    size++;
+  }
+
+  /**
+   * Sort the array according to scores, and return the sorted indexes of previously unsorted
+   * nodes (unchecked nodes).
+   *
+   * @return indexes of newly sorted (unchecked) nodes, in ascending order, or null if the array is
+   *     already fully sorted
+   */
+  public int[] sort() {
+    if (size == sortedNodeSize) {
+      // all nodes checked and sorted
+      return null;
+    }
+    assert sortedNodeSize < size;
+    int[] uncheckedIndexes = new int[size - sortedNodeSize];
+    int count = 0;
+    while (sortedNodeSize != size) {
+      uncheckedIndexes[count] = insertSortedInternal(); // sortedNodeSize is increased inside
+      for (int i = 0; i < count; i++) {
+        if (uncheckedIndexes[i] >= uncheckedIndexes[count]) {
+          // a previously inserted node has been shifted by this insertion
+          uncheckedIndexes[i]++;
+        }
+      }
+      count++;
+    }
+    Arrays.sort(uncheckedIndexes);
+    return uncheckedIndexes;
+  }
+
+  /** insert the first unsorted node into its sorted position */
+  private int insertSortedInternal() {
+    assert sortedNodeSize < size : "Call this method only when there's unsorted node";
+    int tmpNode = node[sortedNodeSize];
+    float tmpScore = score[sortedNodeSize];
     int insertionPoint =
         scoresDescOrder
-            ? descSortFindRightMostInsertionPoint(newScore)
-            : ascSortFindRightMostInsertionPoint(newScore);
-    System.arraycopy(node, insertionPoint, node, insertionPoint + 1, size - insertionPoint);
-    System.arraycopy(score, insertionPoint, score, insertionPoint + 1, size - insertionPoint);
-    node[insertionPoint] = newNode;
-    score[insertionPoint] = newScore;
-    ++size;
+            ? descSortFindRightMostInsertionPoint(tmpScore, sortedNodeSize)
+            : ascSortFindRightMostInsertionPoint(tmpScore, sortedNodeSize);
+    System.arraycopy(
+        node, insertionPoint, node, insertionPoint + 1, sortedNodeSize - insertionPoint);
+    System.arraycopy(
+        score, insertionPoint, score, insertionPoint + 1, sortedNodeSize - insertionPoint);
+    node[insertionPoint] = tmpNode;
+    score[insertionPoint] = tmpScore;
+    ++sortedNodeSize;
+    return insertionPoint;
+  }
+
+  /** This method is for test only. */
+  void insertSorted(int newNode, float newScore) {
+    addOutOfOrder(newNode, newScore);
+    insertSortedInternal();
   }
 
   public int size() {
@@ -97,15 +148,20 @@ public class NeighborArray {
 
   public void clear() {
     size = 0;
+    sortedNodeSize = 0;
   }
 
   public void removeLast() {
     size--;
+    sortedNodeSize = Math.min(sortedNodeSize, size);
   }
 
   public void removeIndex(int idx) {
     System.arraycopy(node, idx + 1, node, idx, size - idx - 1);
     System.arraycopy(score, idx + 1, score, idx, size - idx - 1);
+    if (idx < sortedNodeSize) {
+      sortedNodeSize--;
+    }
     size--;
   }
 
@@ -114,11 +170,11 @@ public class NeighborArray {
     return "NeighborArray[" + size + "]";
   }
 
-  private int ascSortFindRightMostInsertionPoint(float newScore) {
-    int insertionPoint = Arrays.binarySearch(score, 0, size, newScore);
+  private int ascSortFindRightMostInsertionPoint(float newScore, int bound) {
+    int insertionPoint = Arrays.binarySearch(score, 0, bound, newScore);
     if (insertionPoint >= 0) {
       // find the right most position with the same score
-      while ((insertionPoint < size - 1) && (score[insertionPoint + 1] == score[insertionPoint])) {
+      while ((insertionPoint < bound - 1) && (score[insertionPoint + 1] == score[insertionPoint])) {
         insertionPoint++;
       }
       insertionPoint++;
@@ -128,9 +184,9 @@ public class NeighborArray {
     return insertionPoint;
   }
 
-  private int descSortFindRightMostInsertionPoint(float newScore) {
+  private int descSortFindRightMostInsertionPoint(float newScore, int bound) {
     int start = 0;
-    int end = size - 1;
+    int end = bound - 1;
     while (start <= end) {
       int mid = (start + end) / 2;
       if (score[mid] < newScore) end = mid - 1;
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
index ae39614f160..b7f0ecfd075 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
@@ -171,14 +171,14 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
   public long ramBytesUsed() {
     long neighborArrayBytes0 =
         nsize0 * (Integer.BYTES + Float.BYTES)
-            + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2
-            + RamUsageEstimator.NUM_BYTES_OBJECT_REF
-            + Integer.BYTES * 2;
+            + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
+            + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2
+            + Integer.BYTES * 3;
     long neighborArrayBytes =
         nsize * (Integer.BYTES + Float.BYTES)
-            + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2
-            + RamUsageEstimator.NUM_BYTES_OBJECT_REF
-            + Integer.BYTES * 2;
+            + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
+            + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2
+            + Integer.BYTES * 3;
     long total = 0;
     for (int l = 0; l < numLevels; l++) {
       if (l == 0) {
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java
index b8ae24f6200..039f69c9dc4 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java
@@ -23,100 +23,160 @@ public class TestNeighborArray extends LuceneTestCase {
 
   public void testScoresDescOrder() {
     NeighborArray neighbors = new NeighborArray(10, true);
-    neighbors.add(0, 1);
-    neighbors.add(1, 0.8f);
+    neighbors.addInOrder(0, 1);
+    neighbors.addInOrder(1, 0.8f);
 
-    AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.add(2, 0.9f));
+    AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.9f));
     assertEquals("Nodes are added in the incorrect order!", ex.getMessage());
 
     neighbors.insertSorted(3, 0.9f);
     assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors);
-    asserNodesEqual(new int[] {0, 3, 1}, neighbors);
+    assertNodesEqual(new int[] {0, 3, 1}, neighbors);
 
     neighbors.insertSorted(4, 1f);
     assertScoresEqual(new float[] {1, 1, 0.9f, 0.8f}, neighbors);
-    asserNodesEqual(new int[] {0, 4, 3, 1}, neighbors);
+    assertNodesEqual(new int[] {0, 4, 3, 1}, neighbors);
 
     neighbors.insertSorted(5, 1.1f);
     assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f}, neighbors);
-    asserNodesEqual(new int[] {5, 0, 4, 3, 1}, neighbors);
+    assertNodesEqual(new int[] {5, 0, 4, 3, 1}, neighbors);
 
     neighbors.insertSorted(6, 0.8f);
     assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f}, neighbors);
-    asserNodesEqual(new int[] {5, 0, 4, 3, 1, 6}, neighbors);
+    assertNodesEqual(new int[] {5, 0, 4, 3, 1, 6}, neighbors);
 
     neighbors.insertSorted(7, 0.8f);
     assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
-    asserNodesEqual(new int[] {5, 0, 4, 3, 1, 6, 7}, neighbors);
+    assertNodesEqual(new int[] {5, 0, 4, 3, 1, 6, 7}, neighbors);
 
     neighbors.removeIndex(2);
     assertScoresEqual(new float[] {1.1f, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
-    asserNodesEqual(new int[] {5, 0, 3, 1, 6, 7}, neighbors);
+    assertNodesEqual(new int[] {5, 0, 3, 1, 6, 7}, neighbors);
 
     neighbors.removeIndex(0);
     assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors);
-    asserNodesEqual(new int[] {0, 3, 1, 6, 7}, neighbors);
+    assertNodesEqual(new int[] {0, 3, 1, 6, 7}, neighbors);
 
     neighbors.removeIndex(4);
     assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f}, neighbors);
-    asserNodesEqual(new int[] {0, 3, 1, 6}, neighbors);
+    assertNodesEqual(new int[] {0, 3, 1, 6}, neighbors);
 
     neighbors.removeLast();
     assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors);
-    asserNodesEqual(new int[] {0, 3, 1}, neighbors);
+    assertNodesEqual(new int[] {0, 3, 1}, neighbors);
 
     neighbors.insertSorted(8, 0.9f);
     assertScoresEqual(new float[] {1, 0.9f, 0.9f, 0.8f}, neighbors);
-    asserNodesEqual(new int[] {0, 3, 8, 1}, neighbors);
+    assertNodesEqual(new int[] {0, 3, 8, 1}, neighbors);
   }
 
   public void testScoresAscOrder() {
     NeighborArray neighbors = new NeighborArray(10, false);
-    neighbors.add(0, 0.1f);
-    neighbors.add(1, 0.3f);
+    neighbors.addInOrder(0, 0.1f);
+    neighbors.addInOrder(1, 0.3f);
 
-    AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.add(2, 0.15f));
+    AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.15f));
     assertEquals("Nodes are added in the incorrect order!", ex.getMessage());
 
     neighbors.insertSorted(3, 0.3f);
     assertScoresEqual(new float[] {0.1f, 0.3f, 0.3f}, neighbors);
-    asserNodesEqual(new int[] {0, 1, 3}, neighbors);
+    assertNodesEqual(new int[] {0, 1, 3}, neighbors);
 
     neighbors.insertSorted(4, 0.2f);
     assertScoresEqual(new float[] {0.1f, 0.2f, 0.3f, 0.3f}, neighbors);
-    asserNodesEqual(new int[] {0, 4, 1, 3}, neighbors);
+    assertNodesEqual(new int[] {0, 4, 1, 3}, neighbors);
 
     neighbors.insertSorted(5, 0.05f);
     assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.3f, 0.3f}, neighbors);
-    asserNodesEqual(new int[] {5, 0, 4, 1, 3}, neighbors);
+    assertNodesEqual(new int[] {5, 0, 4, 1, 3}, neighbors);
 
     neighbors.insertSorted(6, 0.2f);
     assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors);
-    asserNodesEqual(new int[] {5, 0, 4, 6, 1, 3}, neighbors);
+    assertNodesEqual(new int[] {5, 0, 4, 6, 1, 3}, neighbors);
 
     neighbors.insertSorted(7, 0.2f);
     assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors);
-    asserNodesEqual(new int[] {5, 0, 4, 6, 7, 1, 3}, neighbors);
+    assertNodesEqual(new int[] {5, 0, 4, 6, 7, 1, 3}, neighbors);
 
     neighbors.removeIndex(2);
     assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors);
-    asserNodesEqual(new int[] {5, 0, 6, 7, 1, 3}, neighbors);
+    assertNodesEqual(new int[] {5, 0, 6, 7, 1, 3}, neighbors);
 
     neighbors.removeIndex(0);
     assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors);
-    asserNodesEqual(new int[] {0, 6, 7, 1, 3}, neighbors);
+    assertNodesEqual(new int[] {0, 6, 7, 1, 3}, neighbors);
 
     neighbors.removeIndex(4);
     assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f, 0.3f}, neighbors);
-    asserNodesEqual(new int[] {0, 6, 7, 1}, neighbors);
+    assertNodesEqual(new int[] {0, 6, 7, 1}, neighbors);
 
     neighbors.removeLast();
     assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f}, neighbors);
-    asserNodesEqual(new int[] {0, 6, 7}, neighbors);
+    assertNodesEqual(new int[] {0, 6, 7}, neighbors);
 
     neighbors.insertSorted(8, 0.01f);
     assertScoresEqual(new float[] {0.01f, 0.1f, 0.2f, 0.2f}, neighbors);
-    asserNodesEqual(new int[] {8, 0, 6, 7}, neighbors);
+    assertNodesEqual(new int[] {8, 0, 6, 7}, neighbors);
+  }
+
+  public void testSortAsc() {
+    NeighborArray neighbors = new NeighborArray(10, false);
+    neighbors.addOutOfOrder(1, 2);
+    // we disallow calling addInOrder after addOutOfOrder even if they're actually in order
+    expectThrows(AssertionError.class, () -> neighbors.addInOrder(1, 2));
+    neighbors.addOutOfOrder(2, 3);
+    neighbors.addOutOfOrder(5, 6);
+    neighbors.addOutOfOrder(3, 4);
+    neighbors.addOutOfOrder(7, 8);
+    neighbors.addOutOfOrder(6, 7);
+    neighbors.addOutOfOrder(4, 5);
+    int[] unchecked = neighbors.sort();
+    assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked);
+    assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors);
+    assertScoresEqual(new float[] {2, 3, 4, 5, 6, 7, 8}, neighbors);
+
+    NeighborArray neighbors2 = new NeighborArray(10, false);
+    neighbors2.addInOrder(0, 1);
+    neighbors2.addInOrder(1, 2);
+    neighbors2.addInOrder(4, 5);
+    neighbors2.addOutOfOrder(2, 3);
+    neighbors2.addOutOfOrder(6, 7);
+    neighbors2.addOutOfOrder(5, 6);
+    neighbors2.addOutOfOrder(3, 4);
+    unchecked = neighbors2.sort();
+    assertArrayEquals(new int[] {2, 3, 5, 6}, unchecked);
+    assertNodesEqual(new int[] {0, 1, 2, 3, 4, 5, 6}, neighbors2);
+    assertScoresEqual(new float[] {1, 2, 3, 4, 5, 6, 7}, neighbors2);
+  }
+
+  public void testSortDesc() {
+    NeighborArray neighbors = new NeighborArray(10, true);
+    neighbors.addOutOfOrder(1, 7);
+    // we disallow calling addInOrder after addOutOfOrder even if they're actually in order
+    expectThrows(AssertionError.class, () -> neighbors.addInOrder(1, 2));
+    neighbors.addOutOfOrder(2, 6);
+    neighbors.addOutOfOrder(5, 3);
+    neighbors.addOutOfOrder(3, 5);
+    neighbors.addOutOfOrder(7, 1);
+    neighbors.addOutOfOrder(6, 2);
+    neighbors.addOutOfOrder(4, 4);
+    int[] unchecked = neighbors.sort();
+    assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked);
+    assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors);
+    assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors);
+
+    NeighborArray neighbors2 = new NeighborArray(10, true);
+    neighbors2.addInOrder(1, 7);
+    neighbors2.addInOrder(2, 6);
+    neighbors2.addInOrder(5, 3);
+    neighbors2.addOutOfOrder(3, 5);
+    neighbors2.addOutOfOrder(7, 1);
+    neighbors2.addOutOfOrder(6, 2);
+    neighbors2.addOutOfOrder(4, 4);
+    unchecked = neighbors2.sort();
+    assertArrayEquals(new int[] {2, 3, 5, 6}, unchecked);
+    assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors2);
+    assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors2);
   }
 
   private void assertScoresEqual(float[] scores, NeighborArray neighbors) {
@@ -125,7 +185,7 @@ public class TestNeighborArray extends LuceneTestCase {
     }
   }
 
-  private void asserNodesEqual(int[] nodes, NeighborArray neighbors) {
+  private void assertNodesEqual(int[] nodes, NeighborArray neighbors) {
     for (int i = 0; i < nodes.length; i++) {
       assertEquals(nodes[i], neighbors.node[i]);
     }