You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by ju...@apache.org on 2022/01/18 21:53:13 UTC

[lucene] branch main updated: LUCENE-10375: Write merged vectors to file before building graph (#601)

This is an automated email from the ASF dual-hosted git repository.

julietibs pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/main by this push:
     new dfca9a5  LUCENE-10375: Write merged vectors to file before building graph (#601)
dfca9a5 is described below

commit dfca9a5608526347f95ec6721bad8dd6f6de1c9f
Author: Julie Tibshirani <ju...@apache.org>
AuthorDate: Tue Jan 18 13:53:05 2022 -0800

    LUCENE-10375: Write merged vectors to file before building graph (#601)
    
    When merging segments together, the `KnnVectorsWriter` creates a `VectorValues`
    instance with a merged view of all the segments' vectors. This merged instance
    is used when constructing the new HNSW graph. Graph building needs random
    access, and the merged VectorValues support this by mapping from merged
    ordinals to segments and segment ordinals. This mapping can add significant
    overhead when building the graph.
    
    This change updates the HNSW merging logic to first write the combined segment
    vectors to a file, then use that file to build the graph. This helps speed
    up segment merging, and also lets us simplify `VectorValuesMerger`, which
    provides the merged view of vector values.
---
 lucene/CHANGES.txt                                 |   3 +
 .../org/apache/lucene/codecs/KnnVectorsWriter.java | 239 ++++++---------------
 .../codecs/lucene90/Lucene90HnswVectorsReader.java |  35 +--
 .../codecs/lucene90/Lucene90HnswVectorsWriter.java | 164 +++++++++++---
 .../codecs/perfield/PerFieldKnnVectorsFormat.java  |  29 +++
 .../perfield/TestPerFieldKnnVectorsFormat.java     |   9 +
 .../asserting/AssertingKnnVectorsFormat.java       |   6 +
 7 files changed, 259 insertions(+), 226 deletions(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 9c2b2ec..4333616 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -163,6 +163,9 @@ Optimizations
 * LUCENE-10379: Count directly into the dense values array in FastTaxonomyFacetCounts#countAll.
   (Guo Feng, Greg Miller)
 
+* LUCENE-10375: Speed up HNSW vectors merge by first writing combined vector
+  data to a file. (Julie Tibshirani, Adrien Grand)
+
 Changes in runtime behavior
 ---------------------
 
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java
index 4afa933..2c32fae 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java
@@ -17,19 +17,13 @@
 
 package org.apache.lucene.codecs;
 
-import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
-
 import java.io.Closeable;
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
 import org.apache.lucene.index.DocIDMerger;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.MergeState;
-import org.apache.lucene.index.RandomAccessVectorValues;
-import org.apache.lucene.index.RandomAccessVectorValuesProducer;
-import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.index.VectorValues;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.util.Bits;
@@ -48,7 +42,11 @@ public abstract class KnnVectorsWriter implements Closeable {
   /** Called once at the end before close */
   public abstract void finish() throws IOException;
 
-  /** Merge the vector values from multiple segments, for all fields */
+  /**
+   * Merges the segment vectors for all fields. This default implementation delegates to {@link
+   * #writeField}, passing a {@link KnnVectorsReader} that combines the vector values and ignores
+   * deleted documents.
+   */
   public void merge(MergeState mergeState) throws IOException {
     for (int i = 0; i < mergeState.fieldInfos.length; i++) {
       KnnVectorsReader reader = mergeState.knnVectorsReaders[i];
@@ -57,142 +55,97 @@ public abstract class KnnVectorsWriter implements Closeable {
         reader.checkIntegrity();
       }
     }
+
     for (FieldInfo fieldInfo : mergeState.mergeFieldInfos) {
       if (fieldInfo.hasVectorValues()) {
-        mergeVectors(fieldInfo, mergeState);
-      }
-    }
-    finish();
-  }
+        if (mergeState.infoStream.isEnabled("VV")) {
+          mergeState.infoStream.message("VV", "merging " + mergeState.segmentInfo);
+        }
 
-  private void mergeVectors(FieldInfo mergeFieldInfo, final MergeState mergeState)
-      throws IOException {
-    if (mergeState.infoStream.isEnabled("VV")) {
-      mergeState.infoStream.message("VV", "merging " + mergeState.segmentInfo);
-    }
-    // Create a new VectorValues by iterating over the sub vectors, mapping the resulting
-    // docids using docMaps in the mergeState.
-    writeField(
-        mergeFieldInfo,
-        new KnnVectorsReader() {
-          @Override
-          public long ramBytesUsed() {
-            return 0;
-          }
+        writeField(
+            fieldInfo,
+            new KnnVectorsReader() {
+              @Override
+              public long ramBytesUsed() {
+                return 0;
+              }
 
-          @Override
-          public void close() throws IOException {
-            throw new UnsupportedOperationException();
-          }
+              @Override
+              public void close() {
+                throw new UnsupportedOperationException();
+              }
 
-          @Override
-          public void checkIntegrity() throws IOException {
-            throw new UnsupportedOperationException();
-          }
+              @Override
+              public void checkIntegrity() {
+                throw new UnsupportedOperationException();
+              }
 
-          @Override
-          public VectorValues getVectorValues(String field) throws IOException {
-            List<VectorValuesSub> subs = new ArrayList<>();
-            int dimension = -1;
-            VectorSimilarityFunction similarityFunction = null;
-            int nonEmptySegmentIndex = 0;
-            for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
-              KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
-              if (knnVectorsReader != null) {
-                if (mergeFieldInfo != null && mergeFieldInfo.hasVectorValues()) {
-                  int segmentDimension = mergeFieldInfo.getVectorDimension();
-                  VectorSimilarityFunction segmentSimilarityFunction =
-                      mergeFieldInfo.getVectorSimilarityFunction();
-                  if (dimension == -1) {
-                    dimension = segmentDimension;
-                    similarityFunction = mergeFieldInfo.getVectorSimilarityFunction();
-                  } else if (dimension != segmentDimension) {
-                    throw new IllegalStateException(
-                        "Varying dimensions for vector-valued field "
-                            + mergeFieldInfo.name
-                            + ": "
-                            + dimension
-                            + "!="
-                            + segmentDimension);
-                  } else if (similarityFunction != segmentSimilarityFunction) {
-                    throw new IllegalStateException(
-                        "Varying similarity functions for vector-valued field "
-                            + mergeFieldInfo.name
-                            + ": "
-                            + similarityFunction
-                            + "!="
-                            + segmentSimilarityFunction);
-                  }
-                  VectorValues values = knnVectorsReader.getVectorValues(mergeFieldInfo.name);
-                  if (values != null) {
-                    subs.add(
-                        new VectorValuesSub(nonEmptySegmentIndex++, mergeState.docMaps[i], values));
-                  }
-                }
+              @Override
+              public VectorValues getVectorValues(String field) throws IOException {
+                return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
               }
-            }
-            return new VectorValuesMerger(subs, mergeState);
-          }
 
-          @Override
-          public TopDocs search(String field, float[] target, int k, Bits acceptDocs)
-              throws IOException {
-            throw new UnsupportedOperationException();
-          }
-        });
+              @Override
+              public TopDocs search(String field, float[] target, int k, Bits acceptDocs) {
+                throw new UnsupportedOperationException();
+              }
+            });
 
-    if (mergeState.infoStream.isEnabled("VV")) {
-      mergeState.infoStream.message("VV", "merge done " + mergeState.segmentInfo);
+        if (mergeState.infoStream.isEnabled("VV")) {
+          mergeState.infoStream.message("VV", "merge done " + mergeState.segmentInfo);
+        }
+      }
     }
+    finish();
   }
 
   /** Tracks state of one sub-reader that we are merging */
   private static class VectorValuesSub extends DocIDMerger.Sub {
 
     final VectorValues values;
-    final int segmentIndex;
-    int count;
 
-    VectorValuesSub(int segmentIndex, MergeState.DocMap docMap, VectorValues values) {
+    VectorValuesSub(MergeState.DocMap docMap, VectorValues values) {
       super(docMap);
       this.values = values;
-      this.segmentIndex = segmentIndex;
       assert values.docID() == -1;
     }
 
     @Override
     public int nextDoc() throws IOException {
-      int docId = values.nextDoc();
-      if (docId != NO_MORE_DOCS) {
-        // Note: this does count deleted docs since they are present in the to-be-merged segment
-        ++count;
-      }
-      return docId;
+      return values.nextDoc();
     }
   }
 
-  /**
-   * View over multiple VectorValues supporting iterator-style access via DocIdMerger. Maintains a
-   * reverse ordinal mapping for documents having values in order to support random access by dense
-   * ordinal.
-   */
-  private static class VectorValuesMerger extends VectorValues
-      implements RandomAccessVectorValuesProducer {
+  /** View over multiple VectorValues supporting iterator-style access via DocIdMerger. */
+  public static class MergedVectorValues extends VectorValues {
     private final List<VectorValuesSub> subs;
     private final DocIDMerger<VectorValuesSub> docIdMerger;
-    private final int[] ordBase;
     private final int cost;
-    private int size;
+    private final int size;
 
     private int docId;
     private VectorValuesSub current;
-    /* For each doc with a vector, record its ord in the segments being merged. This enables random
-     * access into the unmerged segments using the ords from the merged segment.
-     */
-    private int[] ordMap;
-    private int ord;
 
-    VectorValuesMerger(List<VectorValuesSub> subs, MergeState mergeState) throws IOException {
+    /** Returns a merged view over all the segment's {@link VectorValues}. */
+    public static MergedVectorValues mergeVectorValues(FieldInfo fieldInfo, MergeState mergeState)
+        throws IOException {
+      assert fieldInfo != null && fieldInfo.hasVectorValues();
+
+      List<VectorValuesSub> subs = new ArrayList<>();
+      for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
+        KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
+        if (knnVectorsReader != null) {
+          VectorValues values = knnVectorsReader.getVectorValues(fieldInfo.name);
+          if (values != null) {
+            subs.add(new VectorValuesSub(mergeState.docMaps[i], values));
+          }
+        }
+      }
+      return new MergedVectorValues(subs, mergeState);
+    }
+
+    private MergedVectorValues(List<VectorValuesSub> subs, MergeState mergeState)
+        throws IOException {
       this.subs = subs;
       docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
       int totalCost = 0, totalSize = 0;
@@ -200,20 +153,8 @@ public abstract class KnnVectorsWriter implements Closeable {
         totalCost += sub.values.cost();
         totalSize += sub.values.size();
       }
-      /* This size includes deleted docs, but when we iterate over docs here (nextDoc())
-       * we skip deleted docs. So we sneakily update this size once we observe that iteration is complete.
-       * That way by the time we are asked to do random access for graph building, we have a correct size.
-       */
       cost = totalCost;
       size = totalSize;
-      ordMap = new int[size];
-      ordBase = new int[subs.size()];
-      int lastBase = 0;
-      for (int k = 0; k < subs.size(); k++) {
-        int size = subs.get(k).values.size();
-        ordBase[k] = lastBase;
-        lastBase += size;
-      }
       docId = -1;
     }
 
@@ -227,12 +168,8 @@ public abstract class KnnVectorsWriter implements Closeable {
       current = docIdMerger.next();
       if (current == null) {
         docId = NO_MORE_DOCS;
-        /* update the size to reflect the number of *non-deleted* documents seen so we can support
-         * random access. */
-        size = ord;
       } else {
         docId = current.mappedDocID;
-        ordMap[ord++] = ordBase[current.segmentIndex] + current.count - 1;
       }
       return docId;
     }
@@ -248,11 +185,6 @@ public abstract class KnnVectorsWriter implements Closeable {
     }
 
     @Override
-    public RandomAccessVectorValues randomAccess() {
-      return new MergerRandomAccess();
-    }
-
-    @Override
     public int advance(int target) {
       throw new UnsupportedOperationException();
     }
@@ -271,52 +203,5 @@ public abstract class KnnVectorsWriter implements Closeable {
     public int dimension() {
       return subs.get(0).values.dimension();
     }
-
-    class MergerRandomAccess implements RandomAccessVectorValues {
-
-      private final List<RandomAccessVectorValues> raSubs;
-
-      MergerRandomAccess() {
-        raSubs = new ArrayList<>(subs.size());
-        for (VectorValuesSub sub : subs) {
-          if (sub.values instanceof RandomAccessVectorValuesProducer) {
-            raSubs.add(((RandomAccessVectorValuesProducer) sub.values).randomAccess());
-          } else {
-            throw new IllegalStateException(
-                "Cannot merge VectorValues without support for random access");
-          }
-        }
-      }
-
-      @Override
-      public int size() {
-        return size;
-      }
-
-      @Override
-      public int dimension() {
-        return VectorValuesMerger.this.dimension();
-      }
-
-      @Override
-      public float[] vectorValue(int target) throws IOException {
-        int unmappedOrd = ordMap[target];
-        int segmentOrd = Arrays.binarySearch(ordBase, unmappedOrd);
-        if (segmentOrd < 0) {
-          // get the index of the greatest lower bound
-          segmentOrd = -2 - segmentOrd;
-        }
-        while (segmentOrd < ordBase.length - 1 && ordBase[segmentOrd + 1] == ordBase[segmentOrd]) {
-          // forward over empty segments which will share the same ordBase
-          segmentOrd++;
-        }
-        return raSubs.get(segmentOrd).vectorValue(unmappedOrd - ordBase[segmentOrd]);
-      }
-
-      @Override
-      public BytesRef binaryValue(int targetOrd) throws IOException {
-        throw new UnsupportedOperationException();
-      }
-    }
   }
 }
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
index 2ae1bc4..11fd80f 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
@@ -271,7 +271,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
   private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
     IndexInput bytesSlice =
         vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
-    return new OffHeapVectorValues(fieldEntry, bytesSlice);
+    return new OffHeapVectorValues(fieldEntry.dimension, fieldEntry.ordToDoc, bytesSlice);
   }
 
   private Bits getAcceptOrds(Bits acceptDocs, FieldEntry fieldEntry) {
@@ -354,10 +354,11 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
   }
 
   /** Read the vector values from the index input. This supports both iterated and random access. */
-  private static class OffHeapVectorValues extends VectorValues
+  public static class OffHeapVectorValues extends VectorValues
       implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
 
-    final FieldEntry fieldEntry;
+    final int dimension;
+    final int[] ordToDoc;
     final IndexInput dataIn;
 
     final BytesRef binaryValue;
@@ -368,23 +369,25 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
     int ord = -1;
     int doc = -1;
 
-    OffHeapVectorValues(FieldEntry fieldEntry, IndexInput dataIn) {
-      this.fieldEntry = fieldEntry;
+    OffHeapVectorValues(int dimension, int[] ordToDoc, IndexInput dataIn) {
+      this.dimension = dimension;
+      this.ordToDoc = ordToDoc;
       this.dataIn = dataIn;
-      byteSize = Float.BYTES * fieldEntry.dimension;
+
+      byteSize = Float.BYTES * dimension;
       byteBuffer = ByteBuffer.allocate(byteSize);
-      value = new float[fieldEntry.dimension];
+      value = new float[dimension];
       binaryValue = new BytesRef(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
     }
 
     @Override
     public int dimension() {
-      return fieldEntry.dimension;
+      return dimension;
     }
 
     @Override
     public int size() {
-      return fieldEntry.size();
+      return ordToDoc.length;
     }
 
     @Override
@@ -411,7 +414,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
       if (++ord >= size()) {
         doc = NO_MORE_DOCS;
       } else {
-        doc = fieldEntry.ordToDoc[ord];
+        doc = ordToDoc[ord];
       }
       return doc;
     }
@@ -419,27 +422,27 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
     @Override
     public int advance(int target) {
       assert docID() < target;
-      ord = Arrays.binarySearch(fieldEntry.ordToDoc, ord + 1, fieldEntry.ordToDoc.length, target);
+      ord = Arrays.binarySearch(ordToDoc, ord + 1, ordToDoc.length, target);
       if (ord < 0) {
         ord = -(ord + 1);
       }
-      assert ord <= fieldEntry.ordToDoc.length;
-      if (ord == fieldEntry.ordToDoc.length) {
+      assert ord <= ordToDoc.length;
+      if (ord == ordToDoc.length) {
         doc = NO_MORE_DOCS;
       } else {
-        doc = fieldEntry.ordToDoc[ord];
+        doc = ordToDoc[ord];
       }
       return doc;
     }
 
     @Override
     public long cost() {
-      return fieldEntry.size();
+      return ordToDoc.length;
     }
 
     @Override
     public RandomAccessVectorValues randomAccess() {
-      return new OffHeapVectorValues(fieldEntry, dataIn.clone());
+      return new OffHeapVectorValues(dimension, ordToDoc, dataIn.clone());
     }
 
     @Override
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsWriter.java
index f512407..f3f468a 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsWriter.java
@@ -26,11 +26,14 @@ import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.codecs.KnnVectorsWriter;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.IndexFileNames;
+import org.apache.lucene.index.MergeState;
 import org.apache.lucene.index.RandomAccessVectorValuesProducer;
 import org.apache.lucene.index.SegmentWriteState;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.index.VectorValues;
+import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.IOUtils;
 import org.apache.lucene.util.hnsw.HnswGraph;
@@ -110,26 +113,14 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
   @Override
   public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
       throws IOException {
+    long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
+
     VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);
-    long pos = vectorData.getFilePointer();
-    // write floats aligned at 4 bytes. This will not survive CFS, but it shows a small benefit when
-    // CFS is not used, eg for larger indexes
-    long padding = (4 - (pos & 0x3)) & 0x3;
-    long vectorDataOffset = pos + padding;
-    for (int i = 0; i < padding; i++) {
-      vectorData.writeByte((byte) 0);
-    }
     // TODO - use a better data structure; a bitset? DocsWithFieldSet is p.p. in o.a.l.index
-    int[] docIds = new int[vectors.size()];
-    int count = 0;
-    for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc(), count++) {
-      // write vector
-      writeVectorValue(vectors);
-      docIds[count] = docV;
-    }
-    // count may be < vectors.size() e,g, if some documents were deleted
-    long[] offsets = new long[count];
-    long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
+    int[] docIds = writeVectorData(vectorData, vectors);
+    assert vectors.size() == docIds.length;
+
+    long[] offsets = new long[docIds.length];
     long vectorIndexOffset = vectorIndex.getFilePointer();
     if (vectors instanceof RandomAccessVectorValuesProducer) {
       writeGraph(
@@ -138,13 +129,14 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
           fieldInfo.getVectorSimilarityFunction(),
           vectorIndexOffset,
           offsets,
-          count,
           maxConn,
           beamWidth);
     } else {
       throw new IllegalArgumentException(
           "Indexing an HNSW graph requires a random access vector values, got " + vectors);
     }
+
+    long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
     long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
     writeMeta(
         fieldInfo,
@@ -152,18 +144,132 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
         vectorDataLength,
         vectorIndexOffset,
         vectorIndexLength,
-        count,
         docIds);
     writeGraphOffsets(meta, offsets);
   }
 
+  @Override
+  public void merge(MergeState mergeState) throws IOException {
+    for (int i = 0; i < mergeState.fieldInfos.length; i++) {
+      KnnVectorsReader reader = mergeState.knnVectorsReaders[i];
+      assert reader != null || mergeState.fieldInfos[i].hasVectorValues() == false;
+      if (reader != null) {
+        reader.checkIntegrity();
+      }
+    }
+
+    for (FieldInfo fieldInfo : mergeState.mergeFieldInfos) {
+      if (fieldInfo.hasVectorValues()) {
+        if (mergeState.infoStream.isEnabled("VV")) {
+          mergeState.infoStream.message("VV", "merging " + mergeState.segmentInfo);
+        }
+        mergeField(fieldInfo, mergeState);
+        if (mergeState.infoStream.isEnabled("VV")) {
+          mergeState.infoStream.message("VV", "merge done " + mergeState.segmentInfo);
+        }
+      }
+    }
+    finish();
+  }
+
+  private void mergeField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
+    if (mergeState.infoStream.isEnabled("VV")) {
+      mergeState.infoStream.message("VV", "merging " + mergeState.segmentInfo);
+    }
+
+    long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
+
+    VectorValues vectors = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
+    IndexOutput tempVectorData =
+        segmentWriteState.directory.createTempOutput(
+            vectorData.getName(), "temp", segmentWriteState.context);
+    IndexInput vectorDataInput = null;
+    boolean success = false;
+    try {
+      // write the merged vector data to a temporary file
+      int[] docIds = writeVectorData(tempVectorData, vectors);
+      CodecUtil.writeFooter(tempVectorData);
+      IOUtils.close(tempVectorData);
+
+      // copy the temporary file vectors to the actual data file
+      vectorDataInput =
+          segmentWriteState.directory.openInput(
+              tempVectorData.getName(), segmentWriteState.context);
+      vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength());
+      CodecUtil.retrieveChecksum(vectorDataInput);
+
+      // build the graph using the temporary vector data
+      Lucene90HnswVectorsReader.OffHeapVectorValues offHeapVectors =
+          new Lucene90HnswVectorsReader.OffHeapVectorValues(
+              vectors.dimension(), docIds, vectorDataInput);
+
+      long[] offsets = new long[docIds.length];
+      long vectorIndexOffset = vectorIndex.getFilePointer();
+      writeGraph(
+          vectorIndex,
+          offHeapVectors,
+          fieldInfo.getVectorSimilarityFunction(),
+          vectorIndexOffset,
+          offsets,
+          maxConn,
+          beamWidth);
+
+      long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
+      long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
+      writeMeta(
+          fieldInfo,
+          vectorDataOffset,
+          vectorDataLength,
+          vectorIndexOffset,
+          vectorIndexLength,
+          docIds);
+      writeGraphOffsets(meta, offsets);
+      success = true;
+    } finally {
+      IOUtils.close(vectorDataInput);
+      if (success) {
+        segmentWriteState.directory.deleteFile(tempVectorData.getName());
+      } else {
+        IOUtils.closeWhileHandlingException(tempVectorData);
+        IOUtils.deleteFilesIgnoringExceptions(
+            segmentWriteState.directory, tempVectorData.getName());
+      }
+    }
+
+    if (mergeState.infoStream.isEnabled("VV")) {
+      mergeState.infoStream.message("VV", "merge done " + mergeState.segmentInfo);
+    }
+  }
+
+  /**
+   * Writes the vector values to the output and returns a mapping from dense ordinals to document
+   * IDs. The length of the returned array matches the total number of documents with a vector
+   * (which excludes deleted documents), so it may be less than {@link VectorValues#size()}.
+   */
+  private static int[] writeVectorData(IndexOutput output, VectorValues vectors)
+      throws IOException {
+    int[] docIds = new int[vectors.size()];
+    int count = 0;
+    for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc(), count++) {
+      // write vector
+      BytesRef binaryValue = vectors.binaryValue();
+      assert binaryValue.length == vectors.dimension() * Float.BYTES;
+      output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
+      docIds[count] = docV;
+    }
+
+    if (docIds.length > count) {
+      return ArrayUtil.copyOfSubArray(docIds, 0, count);
+    }
+    return docIds;
+  }
+
   private void writeMeta(
       FieldInfo field,
       long vectorDataOffset,
       long vectorDataLength,
       long indexDataOffset,
       long indexDataLength,
-      int size,
       int[] docIds)
       throws IOException {
     meta.writeInt(field.number);
@@ -173,20 +279,13 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
     meta.writeVLong(indexDataOffset);
     meta.writeVLong(indexDataLength);
     meta.writeInt(field.getVectorDimension());
-    meta.writeInt(size);
-    for (int i = 0; i < size; i++) {
+    meta.writeInt(docIds.length);
+    for (int docId : docIds) {
       // TODO: delta-encode, or write as bitset
-      meta.writeVInt(docIds[i]);
+      meta.writeVInt(docId);
     }
   }
 
-  private void writeVectorValue(VectorValues vectors) throws IOException {
-    // write vector value
-    BytesRef binaryValue = vectors.binaryValue();
-    assert binaryValue.length == vectors.dimension() * Float.BYTES;
-    vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
-  }
-
   private void writeGraphOffsets(IndexOutput out, long[] offsets) throws IOException {
     long last = 0;
     for (long offset : offsets) {
@@ -201,7 +300,6 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
       VectorSimilarityFunction similarityFunction,
       long graphDataOffset,
       long[] offsets,
-      int count,
       int maxConn,
       int beamWidth)
       throws IOException {
@@ -211,7 +309,7 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
     hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
     HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
 
-    for (int ord = 0; ord < count; ord++) {
+    for (int ord = 0; ord < offsets.length; ord++) {
       // write graph
       offsets[ord] = graphData.getFilePointer() - graphDataOffset;
 
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java
index ee2f931..262fd22 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java
@@ -19,7 +19,10 @@ package org.apache.lucene.codecs.perfield;
 
 import java.io.Closeable;
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
 import java.util.HashMap;
+import java.util.IdentityHashMap;
 import java.util.Map;
 import java.util.ServiceLoader;
 import java.util.TreeMap;
@@ -27,6 +30,7 @@ import org.apache.lucene.codecs.KnnVectorsFormat;
 import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.codecs.KnnVectorsWriter;
 import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.MergeState;
 import org.apache.lucene.index.SegmentReadState;
 import org.apache.lucene.index.SegmentWriteState;
 import org.apache.lucene.index.VectorValues;
@@ -104,6 +108,31 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
     }
 
     @Override
+    public final void merge(MergeState mergeState) throws IOException {
+      Map<KnnVectorsWriter, Collection<String>> writersToFields = new IdentityHashMap<>();
+
+      // Group each writer by the fields it handles
+      for (FieldInfo fi : mergeState.mergeFieldInfos) {
+        if (fi.hasVectorValues() == false) {
+          continue;
+        }
+        KnnVectorsWriter writer = getInstance(fi);
+        Collection<String> fields = writersToFields.computeIfAbsent(writer, k -> new ArrayList<>());
+        fields.add(fi.name);
+      }
+
+      // Delegate the merge to the appropriate writer
+      PerFieldMergeState pfMergeState = new PerFieldMergeState(mergeState);
+      try {
+        for (Map.Entry<KnnVectorsWriter, Collection<String>> e : writersToFields.entrySet()) {
+          e.getKey().merge(pfMergeState.apply(e.getValue()));
+        }
+      } finally {
+        pfMergeState.reset();
+      }
+    }
+
+    @Override
     public void finish() throws IOException {
       for (WriterAndSuffix was : formats.values()) {
         was.writer.finish();
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java
index 8584cc3..e8a46f9 100644
--- a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java
+++ b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java
@@ -36,6 +36,7 @@ import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.IndexWriter;
 import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.MergeState;
 import org.apache.lucene.index.NoMergePolicy;
 import org.apache.lucene.index.SegmentReadState;
 import org.apache.lucene.index.SegmentWriteState;
@@ -178,6 +179,14 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
         }
 
         @Override
+        public void merge(MergeState mergeState) throws IOException {
+          for (FieldInfo fieldInfo : mergeState.mergeFieldInfos) {
+            fieldsWritten.add(fieldInfo.name);
+          }
+          writer.merge(mergeState);
+        }
+
+        @Override
         public void finish() throws IOException {
           writer.finish();
         }
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java
index a38b19c..a4a60db 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java
@@ -23,6 +23,7 @@ import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.codecs.KnnVectorsWriter;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.FieldInfos;
+import org.apache.lucene.index.MergeState;
 import org.apache.lucene.index.SegmentReadState;
 import org.apache.lucene.index.SegmentWriteState;
 import org.apache.lucene.index.VectorValues;
@@ -70,6 +71,11 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
     }
 
     @Override
+    public void merge(MergeState mergeState) throws IOException {
+      delegate.merge(mergeState);
+    }
+
+    @Override
     public void finish() throws IOException {
       delegate.finish();
     }