You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by ju...@apache.org on 2022/01/18 22:10:57 UTC

[lucene] branch branch_9x updated: LUCENE-10183: KnnVectorsWriter#writeField to take KnnVectorsReader instead of VectorValues (#534)

This is an automated email from the ASF dual-hosted git repository.

julietibs pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/branch_9x by this push:
     new af3a0bc  LUCENE-10183: KnnVectorsWriter#writeField to take KnnVectorsReader instead of VectorValues (#534)
af3a0bc is described below

commit af3a0bc4d5669d031618e163e3f1f724e33dfddc
Author: zacharymorn <za...@yahoo.com>
AuthorDate: Thu Jan 6 22:14:41 2022 -0800

    LUCENE-10183: KnnVectorsWriter#writeField to take KnnVectorsReader instead of VectorValues (#534)
---
 lucene/CHANGES.txt                                 |   3 +
 .../simpletext/SimpleTextKnnVectorsWriter.java     |   5 +-
 .../org/apache/lucene/codecs/KnnVectorsWriter.java | 113 +++++++++++++--------
 .../codecs/lucene90/Lucene90HnswVectorsWriter.java |   5 +-
 .../codecs/perfield/PerFieldKnnVectorsFormat.java  |   5 +-
 .../apache/lucene/index/VectorValuesWriter.java    |  42 ++++++--
 .../perfield/TestPerFieldKnnVectorsFormat.java     |   6 +-
 .../asserting/AssertingKnnVectorsFormat.java       |  11 +-
 8 files changed, 133 insertions(+), 57 deletions(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 664dfb1..4688132 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -18,6 +18,9 @@ API Changes
   org.apache.lucene.* to org.apache.lucene.tests.* to avoid package name conflicts with the
   core module. (Dawid Weiss)
 
+* LUCENE-10183: KnnVectorsWriter#writeField now takes a KnnVectorsReader instead of VectorValues.
+  (Zach Chen, Michael Sokolov, Julie Tibshirani, Adrien Grand)
+
 * LUCENE-10335: Deprecate helper methods for resource loading in IOUtils and StopwordAnalyzerBase
   that are not compatible with module system (Class#getResourceAsStream() and Class#getResource()
   are caller sensitive in Java 11). Instead add utility method IOUtils#requireResourceNonNull(T)
diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java
index 270e9db..8b527e0 100644
--- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java
+++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java
@@ -23,6 +23,7 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
+import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.codecs.KnnVectorsWriter;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.IndexFileNames;
@@ -74,7 +75,9 @@ public class SimpleTextKnnVectorsWriter extends KnnVectorsWriter {
   }
 
   @Override
-  public void writeField(FieldInfo fieldInfo, VectorValues vectors) throws IOException {
+  public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
+      throws IOException {
+    VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);
     long vectorDataOffset = vectorData.getFilePointer();
     List<Integer> docIds = new ArrayList<>();
     int docV;
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java
index cd104c4..4afa933 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java
@@ -31,6 +31,8 @@ import org.apache.lucene.index.RandomAccessVectorValues;
 import org.apache.lucene.index.RandomAccessVectorValuesProducer;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.index.VectorValues;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.BytesRef;
 
 /** Writes vectors to an index. */
@@ -40,7 +42,8 @@ public abstract class KnnVectorsWriter implements Closeable {
   protected KnnVectorsWriter() {}
 
   /** Write all values contained in the provided reader */
-  public abstract void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException;
+  public abstract void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
+      throws IOException;
 
   /** Called once at the end before close */
   public abstract void finish() throws IOException;
@@ -67,47 +70,77 @@ public abstract class KnnVectorsWriter implements Closeable {
     if (mergeState.infoStream.isEnabled("VV")) {
       mergeState.infoStream.message("VV", "merging " + mergeState.segmentInfo);
     }
-    List<VectorValuesSub> subs = new ArrayList<>();
-    int dimension = -1;
-    VectorSimilarityFunction similarityFunction = null;
-    int nonEmptySegmentIndex = 0;
-    for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
-      KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
-      if (knnVectorsReader != null) {
-        if (mergeFieldInfo != null && mergeFieldInfo.hasVectorValues()) {
-          int segmentDimension = mergeFieldInfo.getVectorDimension();
-          VectorSimilarityFunction segmentSimilarityFunction =
-              mergeFieldInfo.getVectorSimilarityFunction();
-          if (dimension == -1) {
-            dimension = segmentDimension;
-            similarityFunction = mergeFieldInfo.getVectorSimilarityFunction();
-          } else if (dimension != segmentDimension) {
-            throw new IllegalStateException(
-                "Varying dimensions for vector-valued field "
-                    + mergeFieldInfo.name
-                    + ": "
-                    + dimension
-                    + "!="
-                    + segmentDimension);
-          } else if (similarityFunction != segmentSimilarityFunction) {
-            throw new IllegalStateException(
-                "Varying similarity functions for vector-valued field "
-                    + mergeFieldInfo.name
-                    + ": "
-                    + similarityFunction
-                    + "!="
-                    + segmentSimilarityFunction);
-          }
-          VectorValues values = knnVectorsReader.getVectorValues(mergeFieldInfo.name);
-          if (values != null) {
-            subs.add(new VectorValuesSub(nonEmptySegmentIndex++, mergeState.docMaps[i], values));
-          }
-        }
-      }
-    }
     // Create a new VectorValues by iterating over the sub vectors, mapping the resulting
     // docids using docMaps in the mergeState.
-    writeField(mergeFieldInfo, new VectorValuesMerger(subs, mergeState));
+    writeField(
+        mergeFieldInfo,
+        new KnnVectorsReader() {
+          @Override
+          public long ramBytesUsed() {
+            return 0;
+          }
+
+          @Override
+          public void close() throws IOException {
+            throw new UnsupportedOperationException();
+          }
+
+          @Override
+          public void checkIntegrity() throws IOException {
+            throw new UnsupportedOperationException();
+          }
+
+          @Override
+          public VectorValues getVectorValues(String field) throws IOException {
+            List<VectorValuesSub> subs = new ArrayList<>();
+            int dimension = -1;
+            VectorSimilarityFunction similarityFunction = null;
+            int nonEmptySegmentIndex = 0;
+            for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
+              KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
+              if (knnVectorsReader != null) {
+                if (mergeFieldInfo != null && mergeFieldInfo.hasVectorValues()) {
+                  int segmentDimension = mergeFieldInfo.getVectorDimension();
+                  VectorSimilarityFunction segmentSimilarityFunction =
+                      mergeFieldInfo.getVectorSimilarityFunction();
+                  if (dimension == -1) {
+                    dimension = segmentDimension;
+                    similarityFunction = mergeFieldInfo.getVectorSimilarityFunction();
+                  } else if (dimension != segmentDimension) {
+                    throw new IllegalStateException(
+                        "Varying dimensions for vector-valued field "
+                            + mergeFieldInfo.name
+                            + ": "
+                            + dimension
+                            + "!="
+                            + segmentDimension);
+                  } else if (similarityFunction != segmentSimilarityFunction) {
+                    throw new IllegalStateException(
+                        "Varying similarity functions for vector-valued field "
+                            + mergeFieldInfo.name
+                            + ": "
+                            + similarityFunction
+                            + "!="
+                            + segmentSimilarityFunction);
+                  }
+                  VectorValues values = knnVectorsReader.getVectorValues(mergeFieldInfo.name);
+                  if (values != null) {
+                    subs.add(
+                        new VectorValuesSub(nonEmptySegmentIndex++, mergeState.docMaps[i], values));
+                  }
+                }
+              }
+            }
+            return new VectorValuesMerger(subs, mergeState);
+          }
+
+          @Override
+          public TopDocs search(String field, float[] target, int k, Bits acceptDocs)
+              throws IOException {
+            throw new UnsupportedOperationException();
+          }
+        });
+
     if (mergeState.infoStream.isEnabled("VV")) {
       mergeState.infoStream.message("VV", "merge done " + mergeState.segmentInfo);
     }
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsWriter.java
index 0c2832b..f512407 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsWriter.java
@@ -22,6 +22,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 import java.io.IOException;
 import java.util.Arrays;
 import org.apache.lucene.codecs.CodecUtil;
+import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.codecs.KnnVectorsWriter;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.IndexFileNames;
@@ -107,7 +108,9 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
   }
 
   @Override
-  public void writeField(FieldInfo fieldInfo, VectorValues vectors) throws IOException {
+  public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
+      throws IOException {
+    VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);
     long pos = vectorData.getFilePointer();
     // write floats aligned at 4 bytes. This will not survive CFS, but it shows a small benefit when
     // CFS is not used, eg for larger indexes
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java
index 1ec03da..ee2f931 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java
@@ -98,8 +98,9 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
     }
 
     @Override
-    public void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException {
-      getInstance(fieldInfo).writeField(fieldInfo, values);
+    public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
+        throws IOException {
+      getInstance(fieldInfo).writeField(fieldInfo, knnVectorsReader);
     }
 
     @Override
diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorValuesWriter.java b/lucene/core/src/java/org/apache/lucene/index/VectorValuesWriter.java
index 4b403a3..673f39a 100644
--- a/lucene/core/src/java/org/apache/lucene/index/VectorValuesWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/index/VectorValuesWriter.java
@@ -22,9 +22,12 @@ import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.util.ArrayList;
 import java.util.List;
+import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.codecs.KnnVectorsWriter;
 import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.Counter;
 import org.apache.lucene.util.RamUsageEstimator;
@@ -107,13 +110,38 @@ class VectorValuesWriter {
    * @throws IOException if there is an error writing the field and its values
    */
   public void flush(Sorter.DocMap sortMap, KnnVectorsWriter knnVectorsWriter) throws IOException {
-    VectorValues vectorValues =
-        new BufferedVectorValues(docsWithField, vectors, fieldInfo.getVectorDimension());
-    if (sortMap != null) {
-      knnVectorsWriter.writeField(fieldInfo, new SortingVectorValues(vectorValues, sortMap));
-    } else {
-      knnVectorsWriter.writeField(fieldInfo, vectorValues);
-    }
+    KnnVectorsReader knnVectorsReader =
+        new KnnVectorsReader() {
+          @Override
+          public long ramBytesUsed() {
+            return 0;
+          }
+
+          @Override
+          public void close() throws IOException {
+            throw new UnsupportedOperationException();
+          }
+
+          @Override
+          public void checkIntegrity() throws IOException {
+            throw new UnsupportedOperationException();
+          }
+
+          @Override
+          public VectorValues getVectorValues(String field) throws IOException {
+            VectorValues vectorValues =
+                new BufferedVectorValues(docsWithField, vectors, fieldInfo.getVectorDimension());
+            return sortMap != null ? new SortingVectorValues(vectorValues, sortMap) : vectorValues;
+          }
+
+          @Override
+          public TopDocs search(String field, float[] target, int k, Bits acceptDocs)
+              throws IOException {
+            throw new UnsupportedOperationException();
+          }
+        };
+
+    knnVectorsWriter.writeField(fieldInfo, knnVectorsReader);
   }
 
   static class SortingVectorValues extends VectorValues
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java
index 1181247..8584cc3 100644
--- a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java
+++ b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java
@@ -39,7 +39,6 @@ import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.NoMergePolicy;
 import org.apache.lucene.index.SegmentReadState;
 import org.apache.lucene.index.SegmentWriteState;
-import org.apache.lucene.index.VectorValues;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.analysis.MockAnalyzer;
@@ -172,9 +171,10 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
       KnnVectorsWriter writer = delegate.fieldsWriter(state);
       return new KnnVectorsWriter() {
         @Override
-        public void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException {
+        public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
+            throws IOException {
           fieldsWritten.add(fieldInfo.name);
-          writer.writeField(fieldInfo, values);
+          writer.writeField(fieldInfo, knnVectorsReader);
         }
 
         @Override
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java
index 55b17d5..a38b19c 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java
@@ -58,10 +58,15 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
     }
 
     @Override
-    public void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException {
+    public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
+        throws IOException {
       assert fieldInfo != null;
-      assert values != null;
-      delegate.writeField(fieldInfo, values);
+      assert knnVectorsReader != null;
+      // assert that knnVectorsReader#getVectorValues returns different instances upon repeated
+      // calls
+      assert knnVectorsReader.getVectorValues(fieldInfo.name)
+          != knnVectorsReader.getVectorValues(fieldInfo.name);
+      delegate.writeField(fieldInfo, knnVectorsReader);
     }
 
     @Override