You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by be...@apache.org on 2023/02/07 20:12:13 UTC

[lucene] branch branch_9x updated: Reuse HNSW graph for initialization during merge (#12050)

This is an automated email from the ASF dual-hosted git repository.

benwtrent pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/branch_9x by this push:
     new 133ec845a46 Reuse HNSW graph for initialization during merge (#12050)
133ec845a46 is described below

commit 133ec845a46267d081b54517350f597fc8cad858
Author: John Mazanec <jm...@amazon.com>
AuthorDate: Tue Feb 7 11:42:03 2023 -0800

    Reuse HNSW graph for initialization during merge (#12050)
    
    * Remove implicit addition of vector 0
    
    Removes logic to add 0 vector implicitly. This is in preparation for
    adding nodes from other graphs to initialize a new graph. Having the
    implicit addition of node 0 complicates this logic.
    
    Signed-off-by: John Mazanec <jm...@amazon.com>
    
    * Enable out of order insertion of nodes in hnsw
    
    Enables nodes to be added into OnHeapHnswGraph in out of order fashion.
    To do so, additional operations have to be taken to resort the
    nodesByLevel array. Optimizations have been made to avoid sorting
    whenever possible.
    
    Signed-off-by: John Mazanec <jm...@amazon.com>
    
    * Add ability to initialize from graph
    
    Adds method to initialize an HNSWGraphBuilder from another HNSWGraph.
    Initialization can only happen when the builder's graph is empty.
    
    Signed-off-by: John Mazanec <jm...@amazon.com>
    
    * Utilize merge with graph init in HNSWWriter
    
    Uses HNSWGraphBuilder initialization from graph functionality in
    Lucene95HnswVectorsWriter. Selects the largest graph to initialize the
    new graph produced by the HNSWGraphBuilder for merge.
    
    Signed-off-by: John Mazanec <jm...@amazon.com>
    
    * Minor modifications to Lucene95HnswVectorsWriter
    
    Signed-off-by: John Mazanec <jm...@amazon.com>
    
    * Use TreeMap for graph structure for levels > 0
    
    Refactors OnHeapHnswGraph to use TreeMap to represent graph structure of
    levels greater than 0. Refactors NodesIterator to support set
    representation of nodes.
    
    Signed-off-by: John Mazanec <jm...@amazon.com>
    
    * Refactor initializer to be in static create method
    
    Refactors initialization from graph to be accessible via a create
    static method in HnswGraphBuilder.
    
    Signed-off-by: John Mazanec <jm...@amazon.com>
    
    * Address review comments
    
    Signed-off-by: John Mazanec <jm...@amazon.com>
    
    * Add change log entry
    
    Signed-off-by: John Mazanec <jm...@amazon.com>
    
    * Remove empty iterator for neighborqueue
    
    Signed-off-by: John Mazanec <jm...@amazon.com>
    
    ---------
    
    Signed-off-by: John Mazanec <jm...@amazon.com>
---
 lucene/CHANGES.txt                                 |   2 +
 .../lucene91/Lucene91HnswVectorsReader.java        |   4 +-
 .../lucene91/Lucene91OnHeapHnswGraph.java          |   4 +-
 .../lucene92/Lucene92HnswVectorsReader.java        |   4 +-
 .../lucene94/Lucene94HnswVectorsReader.java        |   4 +-
 .../lucene94/Lucene94HnswVectorsWriter.java        |   7 +-
 .../codecs/lucene95/Lucene95HnswVectorsReader.java |   4 +-
 .../codecs/lucene95/Lucene95HnswVectorsWriter.java | 225 +++++++++++++++++--
 .../org/apache/lucene/util/hnsw/HnswGraph.java     |  90 ++++++--
 .../apache/lucene/util/hnsw/HnswGraphBuilder.java  | 102 ++++++++-
 .../apache/lucene/util/hnsw/HnswGraphSearcher.java |   7 +-
 .../apache/lucene/util/hnsw/OnHeapHnswGraph.java   | 109 ++++-----
 .../test/org/apache/lucene/index/TestKnnGraph.java |  84 -------
 .../apache/lucene/util/hnsw/HnswGraphTestCase.java | 250 ++++++++++++++++++++-
 .../lucene/util/hnsw/TestHnswByteVectorGraph.java  |  28 +++
 .../lucene/util/hnsw/TestHnswFloatVectorGraph.java |  29 +++
 16 files changed, 753 insertions(+), 200 deletions(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 51fe306dbad..8b5026ff691 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -34,6 +34,8 @@ Optimizations
 
 * GITHUB#12128, GITHUB#12133: Speed up docvalues set query by making use of sortedness. (Robert Muir, Uwe Schindler)
 
+* GITHUB#12050: Reuse HNSW graph for intialization during merge (Jack Mazanec)
+
 Bug Fixes
 ---------------------
 (No changes)
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java
index 80757388b3b..462f42c7e52 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java
@@ -561,9 +561,9 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
     @Override
     public NodesIterator getNodesOnLevel(int level) {
       if (level == 0) {
-        return new NodesIterator(size());
+        return new ArrayNodesIterator(size());
       } else {
-        return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length);
+        return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length);
       }
     }
   }
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91OnHeapHnswGraph.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91OnHeapHnswGraph.java
index 2d3ef582b47..e762e016bbf 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91OnHeapHnswGraph.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91OnHeapHnswGraph.java
@@ -163,9 +163,9 @@ public final class Lucene91OnHeapHnswGraph extends HnswGraph {
   @Override
   public NodesIterator getNodesOnLevel(int level) {
     if (level == 0) {
-      return new NodesIterator(size());
+      return new ArrayNodesIterator(size());
     } else {
-      return new NodesIterator(nodesByLevel.get(level), graph.get(level).size());
+      return new ArrayNodesIterator(nodesByLevel.get(level), graph.get(level).size());
     }
   }
 }
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java
index 0802d1c00ec..0f6eddfffed 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java
@@ -457,9 +457,9 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
     @Override
     public NodesIterator getNodesOnLevel(int level) {
       if (level == 0) {
-        return new NodesIterator(size());
+        return new ArrayNodesIterator(size());
       } else {
-        return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length);
+        return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length);
       }
     }
   }
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java
index 38f750621d8..78168171512 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java
@@ -541,9 +541,9 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
     @Override
     public NodesIterator getNodesOnLevel(int level) {
       if (level == 0) {
-        return new NodesIterator(size());
+        return new ArrayNodesIterator(size());
       } else {
-        return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length);
+        return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length);
       }
     }
   }
diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java
index 678ee74bab5..c92482bbe20 100644
--- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java
+++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java
@@ -355,7 +355,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
         if (level == 0) {
           return graph.getNodesOnLevel(0);
         } else {
-          return new NodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
+          return new ArrayNodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
         }
       }
     };
@@ -711,10 +711,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
       assert docID > lastDocID;
       docsWithField.add(docID);
       vectors.add(copyValue(vectorValue));
-      if (node > 0) {
-        // start at node 1! node 0 is added implicitly, in the constructor
-        hnswGraphBuilder.addGraphNode(node, vectorValue);
-      }
+      hnswGraphBuilder.addGraphNode(node, vectorValue);
       node++;
       lastDocID = docID;
     }
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java
index 78f0e47db67..8e724b3f9c9 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java
@@ -581,9 +581,9 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader {
     @Override
     public NodesIterator getNodesOnLevel(int level) {
       if (level == 0) {
-        return new NodesIterator(size());
+        return new ArrayNodesIterator(size());
       } else {
-        return new NodesIterator(nodesByLevel[level], nodesByLevel[level].length);
+        return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length);
       }
     }
   }
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java
index d3de5b80421..265a1033dc6 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java
@@ -25,11 +25,16 @@ import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import org.apache.lucene.codecs.CodecUtil;
 import org.apache.lucene.codecs.KnnFieldVectorsWriter;
+import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.codecs.KnnVectorsWriter;
 import org.apache.lucene.codecs.lucene90.IndexedDISI;
+import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
 import org.apache.lucene.index.*;
 import org.apache.lucene.index.Sorter;
 import org.apache.lucene.search.DocIdSetIterator;
@@ -369,7 +374,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
         if (level == 0) {
           return graph.getNodesOnLevel(0);
         } else {
-          return new NodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
+          return new ArrayNodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
         }
       }
     };
@@ -444,6 +449,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
       OnHeapHnswGraph graph = null;
       int[][] vectorIndexNodeOffsets = null;
       if (docsWithField.cardinality() != 0) {
+        int initializerIndex = selectGraphForInitialization(mergeState, fieldInfo);
         // build graph
         switch (fieldInfo.getVectorEncoding()) {
           case BYTE:
@@ -454,13 +460,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
                     vectorDataInput,
                     byteSize);
             HnswGraphBuilder<byte[]> bytesRefHnswGraphBuilder =
-                HnswGraphBuilder.create(
-                    byteVectorValues,
-                    fieldInfo.getVectorEncoding(),
-                    fieldInfo.getVectorSimilarityFunction(),
-                    M,
-                    beamWidth,
-                    HnswGraphBuilder.randSeed);
+                createHnswGraphBuilder(mergeState, fieldInfo, byteVectorValues, initializerIndex);
             bytesRefHnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
             graph = bytesRefHnswGraphBuilder.build(byteVectorValues.copy());
             break;
@@ -472,13 +472,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
                     vectorDataInput,
                     byteSize);
             HnswGraphBuilder<float[]> hnswGraphBuilder =
-                HnswGraphBuilder.create(
-                    vectorValues,
-                    fieldInfo.getVectorEncoding(),
-                    fieldInfo.getVectorSimilarityFunction(),
-                    M,
-                    beamWidth,
-                    HnswGraphBuilder.randSeed);
+                createHnswGraphBuilder(mergeState, fieldInfo, vectorValues, initializerIndex);
             hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
             graph = hnswGraphBuilder.build(vectorValues.copy());
             break;
@@ -512,6 +506,202 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
     }
   }
 
+  private <T> HnswGraphBuilder<T> createHnswGraphBuilder(
+      MergeState mergeState,
+      FieldInfo fieldInfo,
+      RandomAccessVectorValues<T> floatVectorValues,
+      int initializerIndex)
+      throws IOException {
+    if (initializerIndex == -1) {
+      return HnswGraphBuilder.create(
+          floatVectorValues,
+          fieldInfo.getVectorEncoding(),
+          fieldInfo.getVectorSimilarityFunction(),
+          M,
+          beamWidth,
+          HnswGraphBuilder.randSeed);
+    }
+
+    HnswGraph initializerGraph =
+        getHnswGraphFromReader(fieldInfo.name, mergeState.knnVectorsReaders[initializerIndex]);
+    Map<Integer, Integer> ordinalMapper =
+        getOldToNewOrdinalMap(mergeState, fieldInfo, initializerIndex);
+    return HnswGraphBuilder.create(
+        floatVectorValues,
+        fieldInfo.getVectorEncoding(),
+        fieldInfo.getVectorSimilarityFunction(),
+        M,
+        beamWidth,
+        HnswGraphBuilder.randSeed,
+        initializerGraph,
+        ordinalMapper);
+  }
+
+  private int selectGraphForInitialization(MergeState mergeState, FieldInfo fieldInfo)
+      throws IOException {
+    // Find the KnnVectorReader with the most docs that meets the following criteria:
+    //  1. Does not contain any deleted docs
+    //  2. Is a Lucene95HnswVectorsReader/PerFieldKnnVectorReader
+    // If no readers exist that meet this criteria, return -1. If they do, return their index in
+    // merge state
+    int maxCandidateVectorCount = 0;
+    int initializerIndex = -1;
+
+    for (int i = 0; i < mergeState.liveDocs.length; i++) {
+      KnnVectorsReader currKnnVectorsReader = mergeState.knnVectorsReaders[i];
+      if (mergeState.knnVectorsReaders[i] instanceof PerFieldKnnVectorsFormat.FieldsReader) {
+        PerFieldKnnVectorsFormat.FieldsReader candidateReader =
+            (PerFieldKnnVectorsFormat.FieldsReader) mergeState.knnVectorsReaders[i];
+        currKnnVectorsReader = candidateReader.getFieldReader(fieldInfo.name);
+      }
+
+      if (!allMatch(mergeState.liveDocs[i])
+          || !(currKnnVectorsReader instanceof Lucene95HnswVectorsReader)) {
+        continue;
+      }
+      Lucene95HnswVectorsReader candidateReader = (Lucene95HnswVectorsReader) currKnnVectorsReader;
+
+      int candidateVectorCount = 0;
+      switch (fieldInfo.getVectorEncoding()) {
+        case BYTE:
+          ByteVectorValues byteVectorValues = candidateReader.getByteVectorValues(fieldInfo.name);
+          if (byteVectorValues == null) {
+            continue;
+          }
+          candidateVectorCount = byteVectorValues.size();
+          break;
+        case FLOAT32:
+          FloatVectorValues vectorValues = candidateReader.getFloatVectorValues(fieldInfo.name);
+          if (vectorValues == null) {
+            continue;
+          }
+          candidateVectorCount = vectorValues.size();
+          break;
+      }
+
+      if (candidateVectorCount > maxCandidateVectorCount) {
+        maxCandidateVectorCount = candidateVectorCount;
+        initializerIndex = i;
+      }
+    }
+    return initializerIndex;
+  }
+
+  private HnswGraph getHnswGraphFromReader(String fieldName, KnnVectorsReader knnVectorsReader)
+      throws IOException {
+    if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader) {
+      PerFieldKnnVectorsFormat.FieldsReader perFieldReader =
+          (PerFieldKnnVectorsFormat.FieldsReader) knnVectorsReader;
+      if (perFieldReader.getFieldReader(fieldName) instanceof Lucene95HnswVectorsReader) {
+        Lucene95HnswVectorsReader fieldReader =
+            (Lucene95HnswVectorsReader) perFieldReader.getFieldReader(fieldName);
+        return fieldReader.getGraph(fieldName);
+      }
+    }
+
+    if (knnVectorsReader instanceof Lucene95HnswVectorsReader) {
+      return ((Lucene95HnswVectorsReader) knnVectorsReader).getGraph(fieldName);
+    }
+
+    // We should not reach here because knnVectorsReader's type is checked in
+    // selectGraphForInitialization
+    throw new IllegalArgumentException(
+        "Invalid KnnVectorsReader type for field: "
+            + fieldName
+            + ". Must be Lucene95HnswVectorsReader or newer");
+  }
+
+  private Map<Integer, Integer> getOldToNewOrdinalMap(
+      MergeState mergeState, FieldInfo fieldInfo, int initializerIndex) throws IOException {
+
+    DocIdSetIterator initializerIterator = null;
+
+    switch (fieldInfo.getVectorEncoding()) {
+      case BYTE:
+        initializerIterator =
+            mergeState.knnVectorsReaders[initializerIndex].getByteVectorValues(fieldInfo.name);
+        break;
+      case FLOAT32:
+        initializerIterator =
+            mergeState.knnVectorsReaders[initializerIndex].getFloatVectorValues(fieldInfo.name);
+        break;
+    }
+
+    MergeState.DocMap initializerDocMap = mergeState.docMaps[initializerIndex];
+
+    Map<Integer, Integer> newIdToOldOrdinal = new HashMap<>();
+    int oldOrd = 0;
+    int maxNewDocID = -1;
+    for (int oldId = initializerIterator.nextDoc();
+        oldId != NO_MORE_DOCS;
+        oldId = initializerIterator.nextDoc()) {
+      if (isCurrentVectorNull(initializerIterator)) {
+        continue;
+      }
+      int newId = initializerDocMap.get(oldId);
+      maxNewDocID = Math.max(newId, maxNewDocID);
+      newIdToOldOrdinal.put(newId, oldOrd);
+      oldOrd++;
+    }
+
+    if (maxNewDocID == -1) {
+      return Collections.emptyMap();
+    }
+
+    Map<Integer, Integer> oldToNewOrdinalMap = new HashMap<>();
+
+    DocIdSetIterator vectorIterator = null;
+    switch (fieldInfo.getVectorEncoding()) {
+      case BYTE:
+        vectorIterator = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
+        break;
+      case FLOAT32:
+        vectorIterator = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
+        break;
+    }
+
+    int newOrd = 0;
+    for (int newDocId = vectorIterator.nextDoc();
+        newDocId <= maxNewDocID;
+        newDocId = vectorIterator.nextDoc()) {
+      if (isCurrentVectorNull(vectorIterator)) {
+        continue;
+      }
+
+      if (newIdToOldOrdinal.containsKey(newDocId)) {
+        oldToNewOrdinalMap.put(newIdToOldOrdinal.get(newDocId), newOrd);
+      }
+      newOrd++;
+    }
+
+    return oldToNewOrdinalMap;
+  }
+
+  private boolean isCurrentVectorNull(DocIdSetIterator docIdSetIterator) throws IOException {
+    if (docIdSetIterator instanceof FloatVectorValues) {
+      return ((FloatVectorValues) docIdSetIterator).vectorValue() == null;
+    }
+
+    if (docIdSetIterator instanceof ByteVectorValues) {
+      return ((ByteVectorValues) docIdSetIterator).vectorValue() == null;
+    }
+
+    return true;
+  }
+
+  private boolean allMatch(Bits bits) {
+    if (bits == null) {
+      return true;
+    }
+
+    for (int i = 0; i < bits.length(); i++) {
+      if (!bits.get(i)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   /**
    * @param graph Write the graph in a compressed format
    * @return The non-cumulative offsets for the nodes. Should be used to create cumulative offsets.
@@ -762,10 +952,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
       assert docID > lastDocID;
       docsWithField.add(docID);
       vectors.add(copyValue(vectorValue));
-      if (node > 0) {
-        // start at node 1! node 0 is added implicitly, in the constructor
-        hnswGraphBuilder.addGraphNode(node, vectorValue);
-      }
+      hnswGraphBuilder.addGraphNode(node, vectorValue);
       node++;
       lastDocID = docID;
     }
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
index fc7b0be82fb..9086ab55d2e 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
@@ -20,6 +20,8 @@ package org.apache.lucene.util.hnsw;
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 
 import java.io.IOException;
+import java.util.Collection;
+import java.util.Iterator;
 import java.util.NoSuchElementException;
 import java.util.PrimitiveIterator;
 import org.apache.lucene.index.FloatVectorValues;
@@ -115,7 +117,7 @@ public abstract class HnswGraph {
 
         @Override
         public NodesIterator getNodesOnLevel(int level) {
-          return NodesIterator.EMPTY;
+          return ArrayNodesIterator.EMPTY;
         }
       };
 
@@ -123,33 +125,50 @@ public abstract class HnswGraph {
    * Iterator over the graph nodes on a certain level, Iterator also provides the size – the total
    * number of nodes to be iterated over.
    */
-  public static final class NodesIterator implements PrimitiveIterator.OfInt {
-    static NodesIterator EMPTY = new NodesIterator(0);
-
-    private final int[] nodes;
-    private final int size;
-    int cur = 0;
-
-    /** Constructor for iterator based on the nodes array up to the size */
-    public NodesIterator(int[] nodes, int size) {
-      assert nodes != null;
-      assert size <= nodes.length;
-      this.nodes = nodes;
-      this.size = size;
-    }
+  public abstract static class NodesIterator implements PrimitiveIterator.OfInt {
+    protected final int size;
 
     /** Constructor for iterator based on the size */
     public NodesIterator(int size) {
-      this.nodes = null;
       this.size = size;
     }
 
+    /** The number of elements in this iterator * */
+    public int size() {
+      return size;
+    }
+
     /**
      * Consume integers from the iterator and place them into the `dest` array.
      *
      * @param dest where to put the integers
      * @return The number of integers written to `dest`
      */
+    public abstract int consume(int[] dest);
+  }
+
+  /** NodesIterator that accepts nodes as an integer array. */
+  public static class ArrayNodesIterator extends NodesIterator {
+    static NodesIterator EMPTY = new ArrayNodesIterator(0);
+
+    private final int[] nodes;
+    private int cur = 0;
+
+    /** Constructor for iterator based on integer array representing nodes */
+    public ArrayNodesIterator(int[] nodes, int size) {
+      super(size);
+      assert nodes != null;
+      assert size <= nodes.length;
+      this.nodes = nodes;
+    }
+
+    /** Constructor for iterator based on the size */
+    public ArrayNodesIterator(int size) {
+      super(size);
+      this.nodes = null;
+    }
+
+    @Override
     public int consume(int[] dest) {
       if (hasNext() == false) {
         throw new NoSuchElementException();
@@ -182,10 +201,43 @@ public abstract class HnswGraph {
     public boolean hasNext() {
       return cur < size;
     }
+  }
 
-    /** The number of elements in this iterator * */
-    public int size() {
-      return size;
+  /** Nodes iterator based on set representation of nodes. */
+  public static class CollectionNodesIterator extends NodesIterator {
+    Iterator<Integer> nodes;
+
+    /** Constructor for iterator based on collection representing nodes */
+    public CollectionNodesIterator(Collection<Integer> nodes) {
+      super(nodes.size());
+      this.nodes = nodes.iterator();
+    }
+
+    @Override
+    public int consume(int[] dest) {
+      if (hasNext() == false) {
+        throw new NoSuchElementException();
+      }
+
+      int destIndex = 0;
+      while (hasNext() && destIndex < dest.length) {
+        dest[destIndex++] = nextInt();
+      }
+
+      return destIndex;
+    }
+
+    @Override
+    public int nextInt() {
+      if (hasNext() == false) {
+        throw new NoSuchElementException();
+      }
+      return nodes.next();
+    }
+
+    @Override
+    public boolean hasNext() {
+      return nodes.hasNext();
     }
   }
 }
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index fa77de81e73..2c5e84be285 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -18,10 +18,14 @@
 package org.apache.lucene.util.hnsw;
 
 import static java.lang.Math.log;
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 
 import java.io.IOException;
+import java.util.HashSet;
 import java.util.Locale;
+import java.util.Map;
 import java.util.Objects;
+import java.util.Set;
 import java.util.SplittableRandom;
 import java.util.concurrent.TimeUnit;
 import org.apache.lucene.index.VectorEncoding;
@@ -63,6 +67,7 @@ public final class HnswGraphBuilder<T> {
   // we need two sources of vectors in order to perform diversity check comparisons without
   // colliding
   private final RandomAccessVectorValues<T> vectorsCopy;
+  private final Set<Integer> initializedNodes;
 
   public static <T> HnswGraphBuilder<T> create(
       RandomAccessVectorValues<T> vectors,
@@ -75,6 +80,22 @@ public final class HnswGraphBuilder<T> {
     return new HnswGraphBuilder<>(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed);
   }
 
+  public static <T> HnswGraphBuilder<T> create(
+      RandomAccessVectorValues<T> vectors,
+      VectorEncoding vectorEncoding,
+      VectorSimilarityFunction similarityFunction,
+      int M,
+      int beamWidth,
+      long seed,
+      HnswGraph initializerGraph,
+      Map<Integer, Integer> oldToNewOrdinalMap)
+      throws IOException {
+    HnswGraphBuilder<T> hnswGraphBuilder =
+        new HnswGraphBuilder<>(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed);
+    hnswGraphBuilder.initializeFromGraph(initializerGraph, oldToNewOrdinalMap);
+    return hnswGraphBuilder;
+  }
+
   /**
    * Reads all the vectors from vector values, builds a graph connecting them by their dense
    * ordinals, using the given hyperparameter settings, and returns the resulting graph.
@@ -110,8 +131,7 @@ public final class HnswGraphBuilder<T> {
     // normalization factor for level generation; currently not configurable
     this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M);
     this.random = new SplittableRandom(seed);
-    int levelOfFirstNode = getRandomGraphLevel(ml, random);
-    this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode);
+    this.hnsw = new OnHeapHnswGraph(M);
     this.graphSearcher =
         new HnswGraphSearcher<>(
             vectorEncoding,
@@ -120,6 +140,7 @@ public final class HnswGraphBuilder<T> {
             new FixedBitSet(this.vectors.size()));
     // in scratch we store candidates in reverse order: worse candidates are first
     scratch = new NeighborArray(Math.max(beamWidth, M + 1), false);
+    this.initializedNodes = new HashSet<>();
   }
 
   /**
@@ -142,10 +163,75 @@ public final class HnswGraphBuilder<T> {
     return hnsw;
   }
 
+  /**
+   * Initializes the graph of this builder. Transfers the nodes and their neighbors from the
+   * initializer graph into the graph being produced by this builder, mapping ordinals from the
+   * initializer graph to their new ordinals in this builder's graph. The builder's graph must be
+   * empty before calling this method.
+   *
+   * @param initializerGraph graph used for initialization
+   * @param oldToNewOrdinalMap map for converting from ordinals in the initializerGraph to this
+   *     builder's graph
+   */
+  private void initializeFromGraph(
+      HnswGraph initializerGraph, Map<Integer, Integer> oldToNewOrdinalMap) throws IOException {
+    assert hnsw.size() == 0;
+    float[] vectorValue = null;
+    byte[] binaryValue = null;
+    for (int level = 0; level < initializerGraph.numLevels(); level++) {
+      HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level);
+
+      while (it.hasNext()) {
+        int oldOrd = it.nextInt();
+        int newOrd = oldToNewOrdinalMap.get(oldOrd);
+
+        hnsw.addNode(level, newOrd);
+
+        if (level == 0) {
+          initializedNodes.add(newOrd);
+        }
+
+        switch (this.vectorEncoding) {
+          case FLOAT32:
+            vectorValue = (float[]) vectors.vectorValue(newOrd);
+            break;
+          case BYTE:
+            binaryValue = (byte[]) vectors.vectorValue(newOrd);
+            break;
+        }
+
+        NeighborArray newNeighbors = this.hnsw.getNeighbors(level, newOrd);
+        initializerGraph.seek(level, oldOrd);
+        for (int oldNeighbor = initializerGraph.nextNeighbor();
+            oldNeighbor != NO_MORE_DOCS;
+            oldNeighbor = initializerGraph.nextNeighbor()) {
+          int newNeighbor = oldToNewOrdinalMap.get(oldNeighbor);
+          float score;
+          switch (this.vectorEncoding) {
+            case FLOAT32:
+            default:
+              score =
+                  this.similarityFunction.compare(
+                      vectorValue, (float[]) vectorsCopy.vectorValue(newNeighbor));
+              break;
+            case BYTE:
+              score =
+                  this.similarityFunction.compare(
+                      binaryValue, (byte[]) vectorsCopy.vectorValue(newNeighbor));
+              break;
+          }
+          newNeighbors.insertSorted(newNeighbor, score);
+        }
+      }
+    }
+  }
+
   private void addVectors(RandomAccessVectorValues<T> vectorsToAdd) throws IOException {
     long start = System.nanoTime(), t = start;
-    // start at node 1! node 0 is added implicitly, in the constructor
-    for (int node = 1; node < vectorsToAdd.size(); node++) {
+    for (int node = 0; node < vectorsToAdd.size(); node++) {
+      if (initializedNodes.contains(node)) {
+        continue;
+      }
       addGraphNode(node, vectorsToAdd);
       if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) {
         t = printGraphBuildStatus(node, start, t);
@@ -167,6 +253,14 @@ public final class HnswGraphBuilder<T> {
     NeighborQueue candidates;
     final int nodeLevel = getRandomGraphLevel(ml, random);
     int curMaxLevel = hnsw.numLevels() - 1;
+
+    // If the entry node is -1, then this should finish without adding neighbors
+    if (hnsw.entryNode() == -1) {
+      for (int level = nodeLevel; level >= 0; level--) {
+        hnsw.addNode(level, node);
+      }
+      return;
+    }
     int[] eps = new int[] {hnsw.entryNode()};
 
     // if a node introduces new levels to the graph, add this new node on new levels
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
index 13a338e0d18..4857d5b9d57 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
@@ -101,7 +101,12 @@ public class HnswGraphSearcher<T> {
             new NeighborQueue(topK, true),
             new SparseFixedBitSet(vectors.size()));
     NeighborQueue results;
-    int[] eps = new int[] {graph.entryNode()};
+
+    int initialEp = graph.entryNode();
+    if (initialEp == -1) {
+      return new NeighborQueue(1, true);
+    }
+    int[] eps = new int[] {initialEp};
     int numVisited = 0;
     for (int level = graph.numLevels() - 1; level >= 1; level--) {
       results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit);
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
index 78137c2a630..9862536de08 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
@@ -20,10 +20,9 @@ package org.apache.lucene.util.hnsw;
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
+import java.util.TreeMap;
 import org.apache.lucene.util.Accountable;
-import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.RamUsageEstimator;
 
 /**
@@ -33,19 +32,20 @@ import org.apache.lucene.util.RamUsageEstimator;
 public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
 
   private int numLevels; // the current number of levels in the graph
-  private int entryNode; // the current graph entry node on the top level
+  private int entryNode; // the current graph entry node on the top level. -1 if not set
 
-  // Nodes by level expressed as the level 0's nodes' ordinals.
-  // As level 0 contains all nodes, nodesByLevel.get(0) is null.
-  private final List<int[]> nodesByLevel;
-
-  // graph is a list of graph levels.
-  // Each level is represented as List<NeighborArray> – nodes' connections on this level.
+  // Level 0 is represented as List<NeighborArray> – nodes' connections on level 0.
   // Each entry in the list has the top maxConn/maxConn0 neighbors of a node. The nodes correspond
   // to vectors
   // added to HnswBuilder, and the node values are the ordinals of those vectors.
   // Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals.
-  private final List<List<NeighborArray>> graph;
+  private final List<NeighborArray> graphLevel0;
+  // Represents levels 1-N. Each level is represented with a TreeMap that maps a node's level 0
+  // ordinal to its neighbors on that level. All nodes are in level 0, so we do not need to
+  // maintain it in this list. However, to avoid changing list indexing, we always will make the
+  // first element null.
+  private final List<TreeMap<Integer, NeighborArray>> graphUpperLevels;
   private final int nsize;
   private final int nsize0;
 
@@ -53,24 +53,17 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
   private int upto;
   private NeighborArray cur;
 
-  OnHeapHnswGraph(int M, int levelOfFirstNode) {
-    this.numLevels = levelOfFirstNode + 1;
-    this.graph = new ArrayList<>(numLevels);
-    this.entryNode = 0;
+  OnHeapHnswGraph(int M) {
+    this.numLevels = 1; // Implicitly start the graph with a single level
+    this.graphLevel0 = new ArrayList<>();
+    this.entryNode = -1; // Entry node should be negative until a node is added
     // Neighbours' size on upper levels (nsize) and level 0 (nsize0)
     // We allocate extra space for neighbours, but then prune them to keep allowed maximum
     this.nsize = M + 1;
     this.nsize0 = (M * 2 + 1);
-    for (int l = 0; l < numLevels; l++) {
-      graph.add(new ArrayList<>());
-      graph.get(l).add(new NeighborArray(l == 0 ? nsize0 : nsize, true));
-    }
 
-    this.nodesByLevel = new ArrayList<>(numLevels);
-    nodesByLevel.add(null); // we don't need this for 0th level, as it contains all nodes
-    for (int l = 1; l < numLevels; l++) {
-      nodesByLevel.add(new int[] {0});
-    }
+    this.graphUpperLevels = new ArrayList<>(numLevels);
+    graphUpperLevels.add(null); // we don't need this for 0th level, as it contains all nodes
   }
 
   /**
@@ -81,49 +74,52 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
    */
   public NeighborArray getNeighbors(int level, int node) {
     if (level == 0) {
-      return graph.get(level).get(node);
+      return graphLevel0.get(node);
     }
-    int nodeIndex = Arrays.binarySearch(nodesByLevel.get(level), 0, graph.get(level).size(), node);
-    assert nodeIndex >= 0;
-    return graph.get(level).get(nodeIndex);
+    TreeMap<Integer, NeighborArray> levelMap = graphUpperLevels.get(level);
+    assert levelMap.containsKey(node);
+    return levelMap.get(node);
   }
 
   @Override
   public int size() {
-    return graph.get(0).size(); // all nodes are located on the 0th level
+    return graphLevel0.size(); // all nodes are located on the 0th level
   }
 
   /**
-   * Add node on the given level
+   * Add node on the given level. Nodes can be inserted out of order, but this requires that the
+   * nodes preceding a node inserted out of order are eventually added.
    *
    * @param level level to add a node on
    * @param node the node to add, represented as an ordinal on the level 0.
    */
   public void addNode(int level, int node) {
+    if (entryNode == -1) {
+      entryNode = node;
+    }
+
     if (level > 0) {
       // if the new node introduces a new level, add more levels to the graph,
       // and make this node the graph's new entry point
       if (level >= numLevels) {
         for (int i = numLevels; i <= level; i++) {
-          graph.add(new ArrayList<>());
-          nodesByLevel.add(new int[] {node});
+          graphUpperLevels.add(new TreeMap<>());
         }
         numLevels = level + 1;
         entryNode = node;
-      } else {
-        // Add this node id to this level's nodes
-        int[] nodes = nodesByLevel.get(level);
-        int idx = graph.get(level).size();
-        if (idx < nodes.length) {
-          nodes[idx] = node;
-        } else {
-          nodes = ArrayUtil.grow(nodes);
-          nodes[idx] = node;
-          nodesByLevel.set(level, nodes);
-        }
+      }
+
+      graphUpperLevels.get(level).put(node, new NeighborArray(nsize, true));
+    } else {
+      // Add nodes all the way up to and including "node" in the new graph on level 0. This will
+      // cause the size of the graph to differ from the number of nodes added to the graph. The
+      // size of the graph and the number of nodes added will only be in sync once all nodes from
+      // 0...last_node are added into the graph.
+      while (node >= graphLevel0.size()) {
+        graphLevel0.add(new NeighborArray(nsize0, true));
       }
     }
-    graph.get(level).add(new NeighborArray(level == 0 ? nsize0 : nsize, true));
   }
 
   @Override
@@ -164,9 +160,9 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
   @Override
   public NodesIterator getNodesOnLevel(int level) {
     if (level == 0) {
-      return new NodesIterator(size());
+      return new ArrayNodesIterator(size());
     } else {
-      return new NodesIterator(nodesByLevel.get(level), graph.get(level).size());
+      return new CollectionNodesIterator(graphUpperLevels.get(level).keySet());
     }
   }
 
@@ -184,19 +180,26 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
             + Integer.BYTES * 2;
     long total = 0;
     for (int l = 0; l < numLevels; l++) {
-      int numNodesOnLevel = graph.get(l).size();
       if (l == 0) {
         total +=
-            numNodesOnLevel * neighborArrayBytes0
+            graphLevel0.size() * neighborArrayBytes0
                 + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph;
       } else {
+        long numNodesOnLevel = graphUpperLevels.get(l).size();
+
+        // For levels > 0, we represent the graph structure with a tree map.
+        // A single node in the tree contains 3 references (left root, right root, value) as well
+        // as an Integer for the key and 1 extra byte for the color of the node (this is actually
+        // 1 bit, but because we do not have that granularity, we set to 1 byte). In addition, we
+        // include 1 more reference for the tree map itself.
         total +=
-            nodesByLevel.get(l).length * Integer.BYTES
-                + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
-                + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for nodesByLevel
-        total +=
-            numNodesOnLevel * neighborArrayBytes
-                + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph;
+            numNodesOnLevel * (3L * RamUsageEstimator.NUM_BYTES_OBJECT_REF + Integer.BYTES + 1)
+                + RamUsageEstimator.NUM_BYTES_OBJECT_REF;
+
+        // Add the size of each node's NeighborArray
+        total += numNodesOnLevel * neighborArrayBytes;
       }
     }
     return total;
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
index bba92fab222..08f089430ba 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
@@ -48,14 +48,12 @@ import org.apache.lucene.search.SearcherManager;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.util.LuceneTestCase;
-import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.IOUtils;
 import org.apache.lucene.util.VectorUtil;
 import org.apache.lucene.util.hnsw.HnswGraph;
 import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
-import org.apache.lucene.util.hnsw.HnswGraphBuilder;
 import org.junit.After;
 import org.junit.Before;
 
@@ -179,21 +177,6 @@ public class TestKnnGraph extends LuceneTestCase {
     }
   }
 
-  /**
-   * Verify that we get the *same* graph by indexing one segment as we do by indexing two segments
-   * and merging.
-   */
-  public void testMergeProducesSameGraph() throws Exception {
-    long seed = random().nextLong();
-    int numDoc = atLeast(100);
-    int dimension = atLeast(10);
-    float[][] values = randomVectors(numDoc, dimension);
-    int mergePoint = random().nextInt(numDoc);
-    int[][][] mergedGraph = getIndexedGraph(values, mergePoint, seed);
-    int[][][] singleSegmentGraph = getIndexedGraph(values, -1, seed);
-    assertGraphEquals(singleSegmentGraph, mergedGraph);
-  }
-
   /** Test writing and reading of multiple vector fields * */
   public void testMultipleVectorFields() throws Exception {
     int numVectorFields = randomIntBetween(2, 5);
@@ -227,52 +210,6 @@ public class TestKnnGraph extends LuceneTestCase {
     }
   }
 
-  private void assertGraphEquals(int[][][] expected, int[][][] actual) {
-    assertEquals("graph sizes differ", expected.length, actual.length);
-    for (int level = 0; level < expected.length; level++) {
-      for (int node = 0; node < expected[level].length; node++) {
-        assertArrayEquals("difference at ord=" + node, expected[level][node], actual[level][node]);
-      }
-    }
-  }
-
-  /**
-   * Return a naive representation of an HNSW graph as a 3 dimensional array: 1st dim represents a
-   * graph layer. Each layer contains an array of arrays – a list of nodes and for each node a list
-   * of the node's neighbours. 2nd dim represents a node on a layer, and contains the node's
-   * neighbourhood, or {@code null} if a node is not present on this layer. 3rd dim represents
-   * neighbours of a node.
-   */
-  private int[][][] getIndexedGraph(float[][] values, int mergePoint, long seed)
-      throws IOException {
-    HnswGraphBuilder.randSeed = seed;
-    int[][][] graph;
-    try (Directory dir = newDirectory()) {
-      IndexWriterConfig iwc = newIndexWriterConfig();
-      iwc.setMergePolicy(new LogDocMergePolicy()); // for predictable segment ordering when merging
-      iwc.setCodec(codec); // don't use SimpleTextCodec
-      try (IndexWriter iw = new IndexWriter(dir, iwc)) {
-        for (int i = 0; i < values.length; i++) {
-          add(iw, i, values[i]);
-          if (i == mergePoint) {
-            // flush proactively to create a segment
-            iw.flush();
-          }
-        }
-        iw.forceMerge(1);
-      }
-      try (IndexReader reader = DirectoryReader.open(dir)) {
-        PerFieldKnnVectorsFormat.FieldsReader perFieldReader =
-            (PerFieldKnnVectorsFormat.FieldsReader)
-                ((CodecReader) getOnlyLeafReader(reader)).getVectorReader();
-        Lucene95HnswVectorsReader vectorReader =
-            (Lucene95HnswVectorsReader) perFieldReader.getFieldReader(KNN_GRAPH_FIELD);
-        graph = copyGraph(vectorReader.getGraph(KNN_GRAPH_FIELD));
-      }
-    }
-    return graph;
-  }
-
   private float[][] randomVectors(int numDoc, int dimension) {
     float[][] values = new float[numDoc][];
     for (int i = 0; i < numDoc; i++) {
@@ -297,27 +234,6 @@ public class TestKnnGraph extends LuceneTestCase {
     return value;
   }
 
-  int[][][] copyGraph(HnswGraph graphValues) throws IOException {
-    int[][][] graph = new int[graphValues.numLevels()][][];
-    int size = graphValues.size();
-    int[] scratch = new int[M * 2];
-
-    for (int level = 0; level < graphValues.numLevels(); level++) {
-      NodesIterator nodesItr = graphValues.getNodesOnLevel(level);
-      graph[level] = new int[size][];
-      while (nodesItr.hasNext()) {
-        int node = nodesItr.nextInt();
-        graphValues.seek(level, node);
-        int n, count = 0;
-        while ((n = graphValues.nextNeighbor()) != NO_MORE_DOCS) {
-          scratch[count++] = n;
-        }
-        graph[level][node] = ArrayUtil.copyOfSubArray(scratch, 0, count);
-      }
-    }
-    return graph;
-  }
-
   /** Verify that searching does something reasonable */
   public void testSearch() throws Exception {
     // We can't use dot product here since the vectors are laid out on a grid, not a sphere.
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java
index 7903bcc6160..63aebc40dd7 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java
@@ -25,10 +25,14 @@ import com.carrotsearch.randomizedtesting.RandomizedTest;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.Random;
 import java.util.Set;
+import java.util.stream.Collectors;
 import org.apache.lucene.codecs.KnnVectorsFormat;
 import org.apache.lucene.codecs.lucene95.Lucene95Codec;
 import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
@@ -84,6 +88,12 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
   abstract AbstractMockVectorValues<T> vectorValues(LeafReader reader, String fieldName)
       throws IOException;
 
+  abstract AbstractMockVectorValues<T> vectorValues(
+      int size,
+      int dimension,
+      AbstractMockVectorValues<T> pregeneratedVectorValues,
+      int pregeneratedOffset);
+
   abstract Field knnVectorField(String name, T vector, VectorSimilarityFunction similarityFunction);
 
   abstract RandomAccessVectorValues<T> circularVectorValues(int nDoc);
@@ -449,6 +459,238 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
     }
   }
 
+  public void testBuildOnHeapHnswGraphOutOfOrder() throws IOException {
+    int maxNumLevels = randomIntBetween(2, 10);
+    int nodeCount = randomIntBetween(1, 100);
+
+    List<List<Integer>> nodesPerLevel = new ArrayList<>();
+    for (int i = 0; i < maxNumLevels; i++) {
+      nodesPerLevel.add(new ArrayList<>());
+    }
+
+    int numLevels = 0;
+    for (int currNode = 0; currNode < nodeCount; currNode++) {
+      int nodeMaxLevel = random().nextInt(maxNumLevels) + 1;
+      numLevels = Math.max(numLevels, nodeMaxLevel);
+      for (int currLevel = 0; currLevel < nodeMaxLevel; currLevel++) {
+        nodesPerLevel.get(currLevel).add(currNode);
+      }
+    }
+
+    OnHeapHnswGraph topDownOrderReversedHnsw = new OnHeapHnswGraph(10);
+    for (int currLevel = numLevels - 1; currLevel >= 0; currLevel--) {
+      List<Integer> currLevelNodes = nodesPerLevel.get(currLevel);
+      int currLevelNodesSize = currLevelNodes.size();
+      for (int currNodeInd = currLevelNodesSize - 1; currNodeInd >= 0; currNodeInd--) {
+        topDownOrderReversedHnsw.addNode(currLevel, currLevelNodes.get(currNodeInd));
+      }
+    }
+
+    OnHeapHnswGraph bottomUpOrderReversedHnsw = new OnHeapHnswGraph(10);
+    for (int currLevel = 0; currLevel < numLevels; currLevel++) {
+      List<Integer> currLevelNodes = nodesPerLevel.get(currLevel);
+      int currLevelNodesSize = currLevelNodes.size();
+      for (int currNodeInd = currLevelNodesSize - 1; currNodeInd >= 0; currNodeInd--) {
+        bottomUpOrderReversedHnsw.addNode(currLevel, currLevelNodes.get(currNodeInd));
+      }
+    }
+
+    OnHeapHnswGraph topDownOrderRandomHnsw = new OnHeapHnswGraph(10);
+    for (int currLevel = numLevels - 1; currLevel >= 0; currLevel--) {
+      List<Integer> currLevelNodes = new ArrayList<>(nodesPerLevel.get(currLevel));
+      Collections.shuffle(currLevelNodes, random());
+      for (Integer currNode : currLevelNodes) {
+        topDownOrderRandomHnsw.addNode(currLevel, currNode);
+      }
+    }
+
+    OnHeapHnswGraph bottomUpExpectedHnsw = new OnHeapHnswGraph(10);
+    for (int currLevel = 0; currLevel < numLevels; currLevel++) {
+      for (Integer currNode : nodesPerLevel.get(currLevel)) {
+        bottomUpExpectedHnsw.addNode(currLevel, currNode);
+      }
+    }
+
+    assertEquals(nodeCount, bottomUpExpectedHnsw.getNodesOnLevel(0).size());
+    for (Integer node : nodesPerLevel.get(0)) {
+      assertEquals(0, bottomUpExpectedHnsw.getNeighbors(0, node).size());
+    }
+
+    for (int currLevel = 1; currLevel < numLevels; currLevel++) {
+      NodesIterator nodesIterator = bottomUpExpectedHnsw.getNodesOnLevel(currLevel);
+      List<Integer> expectedNodesOnLevel = nodesPerLevel.get(currLevel);
+      assertEquals(expectedNodesOnLevel.size(), nodesIterator.size());
+      for (Integer expectedNode : expectedNodesOnLevel) {
+        int currentNode = nodesIterator.nextInt();
+        assertEquals(expectedNode.intValue(), currentNode);
+        assertEquals(0, bottomUpExpectedHnsw.getNeighbors(currLevel, currentNode).size());
+      }
+    }
+
+    assertGraphEqual(bottomUpExpectedHnsw, topDownOrderReversedHnsw);
+    assertGraphEqual(bottomUpExpectedHnsw, bottomUpOrderReversedHnsw);
+    assertGraphEqual(bottomUpExpectedHnsw, topDownOrderRandomHnsw);
+  }
+
+  public void testHnswGraphBuilderInitializationFromGraph_withOffsetZero() throws IOException {
+    int totalSize = atLeast(100);
+    int initializerSize = random().nextInt(totalSize - 5) + 5;
+    int docIdOffset = 0;
+    int dim = atLeast(10);
+    long seed = random().nextLong();
+
+    AbstractMockVectorValues<T> initializerVectors = vectorValues(initializerSize, dim);
+    HnswGraphBuilder<T> initializerBuilder =
+        HnswGraphBuilder.create(
+            initializerVectors, getVectorEncoding(), similarityFunction, 10, 30, seed);
+
+    OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.copy());
+    AbstractMockVectorValues<T> finalVectorValues =
+        vectorValues(totalSize, dim, initializerVectors, docIdOffset);
+
+    Map<Integer, Integer> initializerOrdMap =
+        createOffsetOrdinalMap(initializerSize, finalVectorValues, docIdOffset);
+
+    HnswGraphBuilder<T> finalBuilder =
+        HnswGraphBuilder.create(
+            finalVectorValues,
+            getVectorEncoding(),
+            similarityFunction,
+            10,
+            30,
+            seed,
+            initializerGraph,
+            initializerOrdMap);
+
+    // When offset is 0, the graphs should be identical before vectors are added
+    assertGraphEqual(initializerGraph, finalBuilder.getGraph());
+
+    OnHeapHnswGraph finalGraph = finalBuilder.build(finalVectorValues.copy());
+    assertGraphContainsGraph(finalGraph, initializerGraph, initializerOrdMap);
+  }
+
+  public void testHnswGraphBuilderInitializationFromGraph_withNonZeroOffset() throws IOException {
+    int totalSize = atLeast(100);
+    int initializerSize = random().nextInt(totalSize - 5) + 5;
+    int docIdOffset = random().nextInt(totalSize - initializerSize) + 1;
+    int dim = atLeast(10);
+    long seed = random().nextLong();
+
+    AbstractMockVectorValues<T> initializerVectors = vectorValues(initializerSize, dim);
+    HnswGraphBuilder<T> initializerBuilder =
+        HnswGraphBuilder.create(
+            initializerVectors.copy(), getVectorEncoding(), similarityFunction, 10, 30, seed);
+    OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.copy());
+    AbstractMockVectorValues<T> finalVectorValues =
+        vectorValues(totalSize, dim, initializerVectors.copy(), docIdOffset);
+    Map<Integer, Integer> initializerOrdMap =
+        createOffsetOrdinalMap(initializerSize, finalVectorValues.copy(), docIdOffset);
+
+    HnswGraphBuilder<T> finalBuilder =
+        HnswGraphBuilder.create(
+            finalVectorValues,
+            getVectorEncoding(),
+            similarityFunction,
+            10,
+            30,
+            seed,
+            initializerGraph,
+            initializerOrdMap);
+
+    assertGraphInitializedFromGraph(finalBuilder.getGraph(), initializerGraph, initializerOrdMap);
+
+    // Confirm that the graph is appropriately constructed by checking that the nodes in the old
+    // graph are present in the levels of the new graph
+    OnHeapHnswGraph finalGraph = finalBuilder.build(finalVectorValues.copy());
+    assertGraphContainsGraph(finalGraph, initializerGraph, initializerOrdMap);
+  }
+
+  private void assertGraphContainsGraph(
+      HnswGraph g, HnswGraph h, Map<Integer, Integer> oldToNewOrdMap) throws IOException {
+    for (int i = 0; i < h.numLevels(); i++) {
+      int[] finalGraphNodesOnLevel = nodesIteratorToArray(g.getNodesOnLevel(i));
+      int[] initializerGraphNodesOnLevel =
+          mapArrayAndSort(nodesIteratorToArray(h.getNodesOnLevel(i)), oldToNewOrdMap);
+      int overlap = computeOverlap(finalGraphNodesOnLevel, initializerGraphNodesOnLevel);
+      assertEquals(initializerGraphNodesOnLevel.length, overlap);
+    }
+  }
+
+  private void assertGraphInitializedFromGraph(
+      HnswGraph g, HnswGraph h, Map<Integer, Integer> oldToNewOrdMap) throws IOException {
+    assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels());
+    // Confirm that the size of the new graph includes all nodes up to and including the max new
+    // ordinal in the old-to-new ordinal mapping
+    assertEquals(
+        "the number of nodes in the graphs are different!",
+        g.size(),
+        Collections.max(oldToNewOrdMap.values()) + 1);
+
+    // assert the nodes from the previous graph are successfully added to levels > 0 in the new
+    // graph
+    for (int level = 1; level < g.numLevels(); level++) {
+      NodesIterator nodesOnLevel = g.getNodesOnLevel(level);
+      NodesIterator nodesOnLevel2 = h.getNodesOnLevel(level);
+      while (nodesOnLevel.hasNext() && nodesOnLevel2.hasNext()) {
+        int node = nodesOnLevel.nextInt();
+        int node2 = oldToNewOrdMap.get(nodesOnLevel2.nextInt());
+        assertEquals("nodes in the graphs are different", node, node2);
+      }
+    }
+
+    // assert that the neighbors from the old graph are successfully transferred to the new graph
+    for (int level = 0; level < g.numLevels(); level++) {
+      NodesIterator nodesOnLevel = h.getNodesOnLevel(level);
+      while (nodesOnLevel.hasNext()) {
+        int node = nodesOnLevel.nextInt();
+        g.seek(level, oldToNewOrdMap.get(node));
+        h.seek(level, node);
+        assertEquals(
+            "arcs differ for node " + node,
+            getNeighborNodes(g),
+            getNeighborNodes(h).stream().map(oldToNewOrdMap::get).collect(Collectors.toSet()));
+      }
+    }
+  }
+
+  private Map<Integer, Integer> createOffsetOrdinalMap(
+      int docIdSize, AbstractMockVectorValues<T> totalVectorValues, int docIdOffset) {
+    // Compute the offset for the ordinal map to be the number of non-null vectors in the total
+    // vector values
+    // before the docIdOffset
+    int ordinalOffset = 0;
+    while (totalVectorValues.nextDoc() < docIdOffset) {
+      ordinalOffset++;
+    }
+
+    Map<Integer, Integer> offsetOrdinalMap = new HashMap<>();
+    for (int curr = 0;
+        totalVectorValues.docID() < docIdOffset + docIdSize;
+        totalVectorValues.nextDoc()) {
+      offsetOrdinalMap.put(curr, ordinalOffset + curr++);
+    }
+
+    return offsetOrdinalMap;
+  }
+
+  private int[] nodesIteratorToArray(NodesIterator nodesIterator) {
+    int[] arr = new int[nodesIterator.size()];
+    int i = 0;
+    while (nodesIterator.hasNext()) {
+      arr[i++] = nodesIterator.nextInt();
+    }
+    return arr;
+  }
+
+  private int[] mapArrayAndSort(int[] arr, Map<Integer, Integer> map) {
+    int[] mappedA = new int[arr.length];
+    for (int i = 0; i < arr.length; i++) {
+      mappedA[i] = map.get(arr[i]);
+    }
+    Arrays.sort(mappedA);
+    return mappedA;
+  }
+
   @SuppressWarnings("unchecked")
   public void testVisitedLimit() throws IOException {
     int nDoc = 500;
@@ -560,8 +802,8 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
         HnswGraphBuilder.create(
             vectors, getVectorEncoding(), similarityFunction, 2, 10, random().nextInt());
     // node 0 is added by the builder constructor
-    // builder.addGraphNode(vectors.vectorValue(0));
     RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
+    builder.addGraphNode(0, vectorsCopy);
     builder.addGraphNode(1, vectorsCopy);
     builder.addGraphNode(2, vectorsCopy);
     // now every node has tried to attach every other node as a neighbor, but
@@ -615,9 +857,8 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
     HnswGraphBuilder<T> builder =
         HnswGraphBuilder.create(
             vectors, getVectorEncoding(), similarityFunction, 1, 10, random().nextInt());
-    // node 0 is added by the builder constructor
-    // builder.addGraphNode(vectors.vectorValue(0));
     RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
+    builder.addGraphNode(0, vectorsCopy);
     builder.addGraphNode(1, vectorsCopy);
     builder.addGraphNode(2, vectorsCopy);
     assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
@@ -648,9 +889,8 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
     HnswGraphBuilder<T> builder =
         HnswGraphBuilder.create(
             vectors, getVectorEncoding(), similarityFunction, 1, 10, random().nextInt());
-    // node 0 is added by the builder constructor
-    // builder.addGraphNode(vectors.vectorValue(0));
     RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
+    builder.addGraphNode(0, vectorsCopy);
     builder.addGraphNode(1, vectorsCopy);
     builder.addGraphNode(2, vectorsCopy);
     assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java
index 258864ade7d..3a2d92ff92a 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java
@@ -85,6 +85,34 @@ public class TestHnswByteVectorGraph extends HnswGraphTestCase<byte[]> {
     return MockByteVectorValues.fromValues(bValues);
   }
 
+  @Override
+  AbstractMockVectorValues<byte[]> vectorValues(
+      int size,
+      int dimension,
+      AbstractMockVectorValues<byte[]> pregeneratedVectorValues,
+      int pregeneratedOffset) {
+    byte[][] vectors = new byte[size][];
+    byte[][] randomVectors =
+        createRandomByteVectors(size - pregeneratedVectorValues.values.length, dimension, random());
+
+    for (int i = 0; i < pregeneratedOffset; i++) {
+      vectors[i] = randomVectors[i];
+    }
+
+    int currentDoc;
+    while ((currentDoc = pregeneratedVectorValues.nextDoc()) != NO_MORE_DOCS) {
+      vectors[pregeneratedOffset + currentDoc] = pregeneratedVectorValues.values[currentDoc];
+    }
+
+    for (int i = pregeneratedOffset + pregeneratedVectorValues.values.length;
+        i < vectors.length;
+        i++) {
+      vectors[i] = randomVectors[i - pregeneratedVectorValues.values.length];
+    }
+
+    return MockByteVectorValues.fromValues(vectors);
+  }
+
   @Override
   AbstractMockVectorValues<byte[]> vectorValues(LeafReader reader, String fieldName)
       throws IOException {
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java
index 16f2e7330e2..5dda5bf0a83 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java
@@ -79,6 +79,35 @@ public class TestHnswFloatVectorGraph extends HnswGraphTestCase<float[]> {
     return MockVectorValues.fromValues(vectors);
   }
 
+  @Override
+  AbstractMockVectorValues<float[]> vectorValues(
+      int size,
+      int dimension,
+      AbstractMockVectorValues<float[]> pregeneratedVectorValues,
+      int pregeneratedOffset) {
+    float[][] vectors = new float[size][];
+    float[][] randomVectors =
+        createRandomFloatVectors(
+            size - pregeneratedVectorValues.values.length, dimension, random());
+
+    for (int i = 0; i < pregeneratedOffset; i++) {
+      vectors[i] = randomVectors[i];
+    }
+
+    int currentDoc;
+    while ((currentDoc = pregeneratedVectorValues.nextDoc()) != NO_MORE_DOCS) {
+      vectors[pregeneratedOffset + currentDoc] = pregeneratedVectorValues.values[currentDoc];
+    }
+
+    for (int i = pregeneratedOffset + pregeneratedVectorValues.values.length;
+        i < vectors.length;
+        i++) {
+      vectors[i] = randomVectors[i - pregeneratedVectorValues.values.length];
+    }
+
+    return MockVectorValues.fromValues(vectors);
+  }
+
   @Override
   Field knnVectorField(String name, float[] vector, VectorSimilarityFunction similarityFunction) {
     return new KnnFloatVectorField(name, vector, similarityFunction);