You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by jp...@apache.org on 2023/01/11 15:04:58 UTC
[lucene] branch branch_9x updated: Create new KnnByteVectorField and KnnVectorsReader#getByteVectorValues(String) (#12064) (#12075)
This is an automated email from the ASF dual-hosted git repository.
jpountz pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git
The following commit(s) were added to refs/heads/branch_9x by this push:
new 814577217bc Create new KnnByteVectorField and KnnVectorsReader#getByteVectorValues(String) (#12064) (#12075)
814577217bc is described below
commit 814577217bc3fd5343fa0fd7304d7e88e2e3f236
Author: Benjamin Trent <be...@gmail.com>
AuthorDate: Wed Jan 11 10:04:51 2023 -0500
Create new KnnByteVectorField and KnnVectorsReader#getByteVectorValues(String) (#12064) (#12075)
---
.../lucene90/Lucene90HnswGraphBuilder.java | 11 +-
.../lucene90/Lucene90HnswVectorsReader.java | 22 +-
.../lucene90/Lucene90OnHeapHnswGraph.java | 2 +-
.../lucene91/Lucene91HnswVectorsReader.java | 22 +-
.../lucene92/Lucene92HnswVectorsReader.java | 6 +
.../lucene92/OffHeapVectorValues.java | 25 +-
.../lucene94/Lucene94HnswVectorsReader.java | 32 +-
...torValues.java => OffHeapByteVectorValues.java} | 71 +---
.../lucene94/OffHeapVectorValues.java | 25 +-
.../lucene90/Lucene90HnswVectorsWriter.java | 2 +-
.../lucene91/Lucene91HnswGraphBuilder.java | 11 +-
.../lucene91/Lucene91HnswVectorsWriter.java | 2 +-
.../lucene92/Lucene92HnswVectorsWriter.java | 4 +-
.../lucene94/Lucene94HnswVectorsWriter.java | 140 +++++--
.../simpletext/SimpleTextKnnVectorsReader.java | 135 +++++-
.../lucene/codecs/BufferingKnnVectorsWriter.java | 35 +-
.../lucene/codecs/KnnFieldVectorsWriter.java | 2 +-
.../org/apache/lucene/codecs/KnnVectorsFormat.java | 6 +
.../org/apache/lucene/codecs/KnnVectorsReader.java | 8 +
.../org/apache/lucene/codecs/KnnVectorsWriter.java | 233 ++++++++---
.../codecs/lucene95/ExpandingVectorValues.java | 47 ---
.../codecs/lucene95/Lucene95HnswVectorsReader.java | 32 +-
.../codecs/lucene95/Lucene95HnswVectorsWriter.java | 144 ++++---
...torValues.java => OffHeapByteVectorValues.java} | 81 +---
.../codecs/lucene95/OffHeapVectorValues.java | 41 +-
.../codecs/perfield/PerFieldKnnVectorsFormat.java | 11 +
...KnnVectorField.java => KnnByteVectorField.java} | 111 ++---
.../org/apache/lucene/document/KnnVectorField.java | 80 +---
.../{VectorValues.java => ByteVectorValues.java} | 14 +-
.../java/org/apache/lucene/index/CheckIndex.java | 165 +++++---
.../java/org/apache/lucene/index/CodecReader.java | 18 +-
.../apache/lucene/index/DocValuesLeafReader.java | 5 +
.../lucene/index/ExitableDirectoryReader.java | 104 ++++-
.../org/apache/lucene/index/FilterLeafReader.java | 5 +
.../org/apache/lucene/index/IndexingChain.java | 29 +-
.../java/org/apache/lucene/index/LeafReader.java | 8 +
.../apache/lucene/index/ParallelLeafReader.java | 7 +
.../lucene/index/SlowCodecReaderWrapper.java | 5 +
.../apache/lucene/index/SortingCodecReader.java | 96 +++--
.../java/org/apache/lucene/index/VectorValues.java | 4 +-
.../org/apache/lucene/search/FieldExistsQuery.java | 30 +-
.../apache/lucene/search/KnnByteVectorQuery.java | 6 +-
.../org/apache/lucene/search/VectorScorer.java | 59 ++-
.../apache/lucene/util/hnsw/HnswGraphBuilder.java | 51 ++-
.../apache/lucene/util/hnsw/HnswGraphSearcher.java | 14 +-
.../lucene/util/hnsw/RandomAccessVectorValues.java | 19 +-
.../test/org/apache/lucene/document/TestField.java | 10 +-
.../lucene/index/TestSegmentToThreadMapping.java | 5 +
.../lucene/search/TestKnnByteVectorQuery.java | 12 +-
.../org/apache/lucene/search/TestVectorScorer.java | 3 +-
...orValues.java => AbstractMockVectorValues.java} | 68 +---
.../{TestHnswGraph.java => HnswGraphTestCase.java} | 451 ++++++++++-----------
.../apache/lucene/util/hnsw/KnnGraphTester.java | 15 +-
.../lucene/util/hnsw/MockByteVectorValues.java | 70 ++++
.../apache/lucene/util/hnsw/MockVectorValues.java | 83 +---
.../lucene/util/hnsw/TestHnswByteVectorGraph.java | 118 ++++++
.../lucene/util/hnsw/TestHnswFloatVectorGraph.java | 133 ++++++
.../search/highlight/TermVectorLeafReader.java | 6 +
.../apache/lucene/index/memory/MemoryIndex.java | 5 +
.../asserting/AssertingKnnVectorsFormat.java | 19 +-
.../tests/index/BaseKnnVectorsFormatTestCase.java | 106 +++--
.../lucene/tests/index/MergeReaderWrapper.java | 6 +
.../org/apache/lucene/tests/search/QueryUtils.java | 6 +
63 files changed, 1914 insertions(+), 1182 deletions(-)
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java
index f0fa3fe41e3..35ba1f8a1cd 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java
@@ -48,7 +48,7 @@ public final class Lucene90HnswGraphBuilder {
private final Lucene90NeighborArray scratch;
private final VectorSimilarityFunction similarityFunction;
- private final RandomAccessVectorValues vectorValues;
+ private final RandomAccessVectorValues<float[]> vectorValues;
private final SplittableRandom random;
private final Lucene90BoundsChecker bound;
final Lucene90OnHeapHnswGraph hnsw;
@@ -57,7 +57,7 @@ public final class Lucene90HnswGraphBuilder {
// we need two sources of vectors in order to perform diversity check comparisons without
// colliding
- private final RandomAccessVectorValues buildVectors;
+ private final RandomAccessVectorValues<float[]> buildVectors;
/**
* Reads all the vectors from a VectorValues, builds a graph connecting them by their dense
@@ -72,7 +72,7 @@ public final class Lucene90HnswGraphBuilder {
* to ensure repeatable construction.
*/
public Lucene90HnswGraphBuilder(
- RandomAccessVectorValues vectors,
+ RandomAccessVectorValues<float[]> vectors,
VectorSimilarityFunction similarityFunction,
int maxConn,
int beamWidth,
@@ -103,7 +103,8 @@ public final class Lucene90HnswGraphBuilder {
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
* accessor for the vectors
*/
- public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
+ public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues<float[]> vectors)
+ throws IOException {
if (vectors == vectorValues) {
throw new IllegalArgumentException(
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
@@ -229,7 +230,7 @@ public final class Lucene90HnswGraphBuilder {
float[] candidate,
float score,
Lucene90NeighborArray neighbors,
- RandomAccessVectorValues vectorValues)
+ RandomAccessVectorValues<float[]> vectorValues)
throws IOException {
bound.set(score);
for (int i = 0; i < neighbors.size(); i++) {
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java
index ec282a6d151..f33aa13e09f 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java
@@ -27,6 +27,7 @@ import java.util.Map;
import java.util.SplittableRandom;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
@@ -232,6 +233,11 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
return getOffHeapVectorValues(fieldEntry);
}
+ @Override
+ public ByteVectorValues getByteVectorValues(String field) {
+ throw new UnsupportedOperationException();
+ }
+
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
@@ -352,7 +358,8 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
}
/** Read the vector values from the index input. This supports both iterated and random access. */
- static class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
+ static class OffHeapVectorValues extends VectorValues
+ implements RandomAccessVectorValues<float[]> {
final int dimension;
final int[] ordToDoc;
@@ -433,7 +440,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
}
@Override
- public RandomAccessVectorValues copy() {
+ public RandomAccessVectorValues<float[]> copy() {
return new OffHeapVectorValues(dimension, ordToDoc, dataIn.clone());
}
@@ -443,17 +450,6 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
dataIn.readFloats(value, 0, value.length);
return value;
}
-
- @Override
- public BytesRef binaryValue(int targetOrd) throws IOException {
- readValue(targetOrd);
- return binaryValue;
- }
-
- private void readValue(int targetOrd) throws IOException {
- dataIn.seek((long) targetOrd * byteSize);
- dataIn.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
- }
}
/** Read the nearest-neighbors graph from the index input */
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java
index e3b15bf22b1..52dae288caa 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java
@@ -74,7 +74,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
float[] query,
int topK,
int numSeed,
- RandomAccessVectorValues vectors,
+ RandomAccessVectorValues<float[]> vectors,
VectorSimilarityFunction similarityFunction,
HnswGraph graphValues,
Bits acceptOrds,
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java
index 8160327ef53..6367907f8c5 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java
@@ -27,6 +27,7 @@ import java.util.Map;
import java.util.function.IntUnaryOperator;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
@@ -224,6 +225,11 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
return getOffHeapVectorValues(fieldEntry);
}
+ @Override
+ public ByteVectorValues getByteVectorValues(String field) throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
@@ -398,7 +404,8 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
}
/** Read the vector values from the index input. This supports both iterated and random access. */
- static class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
+ static class OffHeapVectorValues extends VectorValues
+ implements RandomAccessVectorValues<float[]> {
private final int dimension;
private final int size;
@@ -486,7 +493,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
}
@Override
- public RandomAccessVectorValues copy() {
+ public RandomAccessVectorValues<float[]> copy() {
return new OffHeapVectorValues(dimension, size, ordToDoc, dataIn.clone());
}
@@ -496,17 +503,6 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
dataIn.readFloats(value, 0, value.length);
return value;
}
-
- @Override
- public BytesRef binaryValue(int targetOrd) throws IOException {
- readValue(targetOrd);
- return binaryValue;
- }
-
- private void readValue(int targetOrd) throws IOException {
- dataIn.seek((long) targetOrd * byteSize);
- dataIn.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
- }
}
/** Read the nearest-neighbors graph from the index input */
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java
index 8cc3ad694df..bc975b3dde1 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java
@@ -25,6 +25,7 @@ import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
@@ -219,6 +220,11 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
return OffHeapVectorValues.load(fieldEntry, vectorData);
}
+ @Override
+ public ByteVectorValues getByteVectorValues(String field) {
+ throw new UnsupportedOperationException();
+ }
+
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapVectorValues.java
index b1a0eb5b5ea..8ddf86cacfa 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapVectorValues.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapVectorValues.java
@@ -29,7 +29,8 @@ import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
-abstract class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
+abstract class OffHeapVectorValues extends VectorValues
+ implements RandomAccessVectorValues<float[]> {
protected final int dimension;
protected final int size;
@@ -66,17 +67,6 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
return value;
}
- @Override
- public BytesRef binaryValue(int targetOrd) throws IOException {
- readValue(targetOrd);
- return binaryValue;
- }
-
- private void readValue(int targetOrd) throws IOException {
- slice.seek((long) targetOrd * byteSize);
- slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
- }
-
public abstract int ordToDoc(int ord);
static OffHeapVectorValues load(
@@ -137,7 +127,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public RandomAccessVectorValues copy() throws IOException {
+ public RandomAccessVectorValues<float[]> copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone());
}
@@ -210,7 +200,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public RandomAccessVectorValues copy() throws IOException {
+ public RandomAccessVectorValues<float[]> copy() throws IOException {
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone());
}
@@ -282,7 +272,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public RandomAccessVectorValues copy() throws IOException {
+ public RandomAccessVectorValues<float[]> copy() throws IOException {
throw new UnsupportedOperationException();
}
@@ -291,11 +281,6 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
throw new UnsupportedOperationException();
}
- @Override
- public BytesRef binaryValue(int targetOrd) throws IOException {
- throw new UnsupportedOperationException();
- }
-
@Override
public int ordToDoc(int ord) {
throw new UnsupportedOperationException();
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java
index 2cb03a4906b..9920f0f7b87 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java
@@ -25,6 +25,7 @@ import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
@@ -241,12 +242,31 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
@Override
public VectorValues getVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
- VectorValues values = OffHeapVectorValues.load(fieldEntry, vectorData);
- if (fieldEntry.vectorEncoding == VectorEncoding.BYTE) {
- return new ExpandingVectorValues(values);
- } else {
- return values;
+ if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
+ throw new IllegalArgumentException(
+ "field=\""
+ + field
+ + "\" is encoded as: "
+ + fieldEntry.vectorEncoding
+ + " expected: "
+ + VectorEncoding.FLOAT32);
}
+ return OffHeapVectorValues.load(fieldEntry, vectorData);
+ }
+
+ @Override
+ public ByteVectorValues getByteVectorValues(String field) throws IOException {
+ FieldEntry fieldEntry = fields.get(field);
+ if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
+ throw new IllegalArgumentException(
+ "field=\""
+ + field
+ + "\" is encoded as: "
+ + fieldEntry.vectorEncoding
+ + " expected: "
+ + VectorEncoding.BYTE); // this accessor only serves BYTE-encoded fields
+ }
+ return OffHeapByteVectorValues.load(fieldEntry, vectorData);
}
@Override
@@ -300,7 +320,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
// bound k by total number of vectors to prevent oversizing data structures
k = Math.min(k, fieldEntry.size());
- OffHeapVectorValues vectorValues = OffHeapVectorValues.load(fieldEntry, vectorData);
+ OffHeapByteVectorValues vectorValues = OffHeapByteVectorValues.load(fieldEntry, vectorData);
NeighborQueue results =
HnswGraphSearcher.search(
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java
similarity index 78%
copy from lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapVectorValues.java
copy to lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java
index 0705f8cf512..2ac67212b6d 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapVectorValues.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java
@@ -20,7 +20,8 @@ package org.apache.lucene.backward_codecs.lucene94;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
-import org.apache.lucene.index.VectorValues;
+import org.apache.lucene.index.ByteVectorValues;
+import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
@@ -29,7 +30,8 @@ import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
-abstract class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
+abstract class OffHeapByteVectorValues extends ByteVectorValues
+ implements RandomAccessVectorValues<BytesRef> {
protected final int dimension;
protected final int size;
@@ -37,15 +39,13 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
protected final BytesRef binaryValue;
protected final ByteBuffer byteBuffer;
protected final int byteSize;
- protected final float[] value;
- OffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
+ OffHeapByteVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
this.dimension = dimension;
this.size = size;
this.slice = slice;
this.byteSize = byteSize;
byteBuffer = ByteBuffer.allocate(byteSize);
- value = new float[dimension];
binaryValue = new BytesRef(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
}
@@ -60,14 +60,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public float[] vectorValue(int targetOrd) throws IOException {
- slice.seek((long) targetOrd * byteSize);
- slice.readFloats(value, 0, value.length);
- return value;
- }
-
- @Override
- public BytesRef binaryValue(int targetOrd) throws IOException {
+ public BytesRef vectorValue(int targetOrd) throws IOException {
readValue(targetOrd);
return binaryValue;
}
@@ -79,14 +72,14 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
public abstract int ordToDoc(int ord);
- static OffHeapVectorValues load(
+ static OffHeapByteVectorValues load(
Lucene94HnswVectorsReader.FieldEntry fieldEntry, IndexInput vectorData) throws IOException {
- if (fieldEntry.docsWithFieldOffset == -2) {
+ if (fieldEntry.docsWithFieldOffset == -2 || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
return new EmptyOffHeapVectorValues(fieldEntry.dimension);
}
IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
- int byteSize = fieldEntry.dimension * fieldEntry.vectorEncoding.byteSize;
+ int byteSize = fieldEntry.dimension;
if (fieldEntry.docsWithFieldOffset == -1) {
return new DenseOffHeapVectorValues(
fieldEntry.dimension, fieldEntry.size, bytesSlice, byteSize);
@@ -97,7 +90,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
abstract Bits getAcceptOrds(Bits acceptDocs);
- static class DenseOffHeapVectorValues extends OffHeapVectorValues {
+ static class DenseOffHeapVectorValues extends OffHeapByteVectorValues {
private int doc = -1;
@@ -106,16 +99,9 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public float[] vectorValue() throws IOException {
+ public BytesRef vectorValue() throws IOException {
slice.seek((long) doc * byteSize);
- slice.readFloats(value, 0, value.length);
- return value;
- }
-
- @Override
- public BytesRef binaryValue() throws IOException {
- slice.seek((long) doc * byteSize);
- slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize, false);
+ slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
return binaryValue;
}
@@ -139,7 +125,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public RandomAccessVectorValues copy() throws IOException {
+ public RandomAccessVectorValues<BytesRef> copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
}
@@ -154,7 +140,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
}
- private static class SparseOffHeapVectorValues extends OffHeapVectorValues {
+ private static class SparseOffHeapVectorValues extends OffHeapByteVectorValues {
private final DirectMonotonicReader ordToDoc;
private final IndexedDISI disi;
// dataIn was used to init a new IndexedDIS for #randomAccess()
@@ -185,14 +171,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public float[] vectorValue() throws IOException {
- slice.seek((long) (disi.index()) * byteSize);
- slice.readFloats(value, 0, value.length);
- return value;
- }
-
- @Override
- public BytesRef binaryValue() throws IOException {
+ public BytesRef vectorValue() throws IOException {
slice.seek((long) (disi.index()) * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize, false);
return binaryValue;
@@ -215,7 +194,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public RandomAccessVectorValues copy() throws IOException {
+ public RandomAccessVectorValues<BytesRef> copy() throws IOException {
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
}
@@ -243,7 +222,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
}
- private static class EmptyOffHeapVectorValues extends OffHeapVectorValues {
+ private static class EmptyOffHeapVectorValues extends OffHeapByteVectorValues {
public EmptyOffHeapVectorValues(int dimension) {
super(dimension, 0, null, 0);
@@ -262,12 +241,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public float[] vectorValue() throws IOException {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public BytesRef binaryValue() throws IOException {
+ public BytesRef vectorValue() throws IOException {
throw new UnsupportedOperationException();
}
@@ -287,17 +261,12 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public RandomAccessVectorValues copy() throws IOException {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public float[] vectorValue(int targetOrd) throws IOException {
+ public RandomAccessVectorValues<BytesRef> copy() throws IOException {
throw new UnsupportedOperationException();
}
@Override
- public BytesRef binaryValue(int targetOrd) throws IOException {
+ public BytesRef vectorValue(int targetOrd) throws IOException {
throw new UnsupportedOperationException();
}
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapVectorValues.java
index 0705f8cf512..e10fa45ce62 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapVectorValues.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapVectorValues.java
@@ -29,7 +29,8 @@ import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
-abstract class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
+abstract class OffHeapVectorValues extends VectorValues
+ implements RandomAccessVectorValues<float[]> {
protected final int dimension;
protected final int size;
@@ -66,17 +67,6 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
return value;
}
- @Override
- public BytesRef binaryValue(int targetOrd) throws IOException {
- readValue(targetOrd);
- return binaryValue;
- }
-
- private void readValue(int targetOrd) throws IOException {
- slice.seek((long) targetOrd * byteSize);
- slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
- }
-
public abstract int ordToDoc(int ord);
static OffHeapVectorValues load(
@@ -139,7 +129,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public RandomAccessVectorValues copy() throws IOException {
+ public RandomAccessVectorValues<float[]> copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
}
@@ -215,7 +205,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public RandomAccessVectorValues copy() throws IOException {
+ public RandomAccessVectorValues<float[]> copy() throws IOException {
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
}
@@ -287,7 +277,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public RandomAccessVectorValues copy() throws IOException {
+ public RandomAccessVectorValues<float[]> copy() throws IOException {
throw new UnsupportedOperationException();
}
@@ -296,11 +286,6 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
throw new UnsupportedOperationException();
}
- @Override
- public BytesRef binaryValue(int targetOrd) throws IOException {
- throw new UnsupportedOperationException();
- }
-
@Override
public int ordToDoc(int ord) {
throw new UnsupportedOperationException();
diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java
index 6478152e104..4952029e947 100644
--- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java
+++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java
@@ -224,7 +224,7 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {
private void writeGraph(
IndexOutput graphData,
- RandomAccessVectorValues vectorValues,
+ RandomAccessVectorValues<float[]> vectorValues,
VectorSimilarityFunction similarityFunction,
long graphDataOffset,
long[] offsets,
diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java
index 5f37e905e2f..75f2aa4b659 100644
--- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java
+++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java
@@ -53,7 +53,7 @@ public final class Lucene91HnswGraphBuilder {
private final Lucene91NeighborArray scratch;
private final VectorSimilarityFunction similarityFunction;
- private final RandomAccessVectorValues vectorValues;
+ private final RandomAccessVectorValues<float[]> vectorValues;
private final SplittableRandom random;
private final Lucene91BoundsChecker bound;
private final HnswGraphSearcher<float[]> graphSearcher;
@@ -64,7 +64,7 @@ public final class Lucene91HnswGraphBuilder {
// we need two sources of vectors in order to perform diversity check comparisons without
// colliding
- private RandomAccessVectorValues buildVectors;
+ private RandomAccessVectorValues<float[]> buildVectors;
/**
* Reads all the vectors from a VectorValues, builds a graph connecting them by their dense
@@ -79,7 +79,7 @@ public final class Lucene91HnswGraphBuilder {
* to ensure repeatable construction.
*/
public Lucene91HnswGraphBuilder(
- RandomAccessVectorValues vectors,
+ RandomAccessVectorValues<float[]> vectors,
VectorSimilarityFunction similarityFunction,
int maxConn,
int beamWidth,
@@ -119,7 +119,8 @@ public final class Lucene91HnswGraphBuilder {
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
* accessor for the vectors
*/
- public Lucene91OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
+ public Lucene91OnHeapHnswGraph build(RandomAccessVectorValues<float[]> vectors)
+ throws IOException {
if (vectors == vectorValues) {
throw new IllegalArgumentException(
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
@@ -250,7 +251,7 @@ public final class Lucene91HnswGraphBuilder {
float[] candidate,
float score,
Lucene91NeighborArray neighbors,
- RandomAccessVectorValues vectorValues)
+ RandomAccessVectorValues<float[]> vectorValues)
throws IOException {
bound.set(score);
for (int i = 0; i < neighbors.size(); i++) {
diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java
index 5e9b62c8dd7..2c603aba621 100644
--- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java
+++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java
@@ -233,7 +233,7 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
}
private Lucene91OnHeapHnswGraph writeGraph(
- RandomAccessVectorValues vectorValues, VectorSimilarityFunction similarityFunction)
+ RandomAccessVectorValues<float[]> vectorValues, VectorSimilarityFunction similarityFunction)
throws IOException {
// build graph
diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java
index 0e6c72186c7..df0c1032064 100644
--- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java
+++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java
@@ -268,13 +268,13 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
}
private OnHeapHnswGraph writeGraph(
- RandomAccessVectorValues vectorValues,
+ RandomAccessVectorValues<float[]> vectorValues,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction)
throws IOException {
// build graph
- HnswGraphBuilder<?> hnswGraphBuilder =
+ HnswGraphBuilder<float[]> hnswGraphBuilder =
HnswGraphBuilder.create(
vectorValues,
vectorEncoding,
diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java
index 7e2815b4ba6..70318077b41 100644
--- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java
+++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java
@@ -30,12 +30,14 @@ import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
+import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
+import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IndexInput;
@@ -389,8 +391,6 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
@Override
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
- VectorValues vectors = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
-
IndexOutput tempVectorData =
segmentWriteState.directory.createTempOutput(
vectorData.getName(), "temp", segmentWriteState.context);
@@ -398,8 +398,22 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
boolean success = false;
try {
// write the vector data to a temporary file
- DocsWithFieldSet docsWithField =
- writeVectorData(tempVectorData, vectors, fieldInfo.getVectorEncoding().byteSize);
+ final DocsWithFieldSet docsWithField;
+ switch (fieldInfo.getVectorEncoding()) {
+ case BYTE:
+ docsWithField =
+ writeByteVectorData(
+ tempVectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
+ break;
+ case FLOAT32:
+ docsWithField =
+ writeVectorData(
+ tempVectorData, MergedVectorValues.mergeVectorValues(fieldInfo, mergeState));
+ break;
+ default:
+ throw new IllegalArgumentException(
+ "unknown vector encoding=" + fieldInfo.getVectorEncoding());
+ }
CodecUtil.writeFooter(tempVectorData);
IOUtils.close(tempVectorData);
@@ -415,23 +429,51 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
// we use Lucene94HnswVectorsReader.DenseOffHeapVectorValues for the graph construction
// doesn't need to know docIds
// TODO: separate random access vector values from DocIdSetIterator?
- int byteSize = vectors.dimension() * fieldInfo.getVectorEncoding().byteSize;
- OffHeapVectorValues offHeapVectors =
- new OffHeapVectorValues.DenseOffHeapVectorValues(
- vectors.dimension(), docsWithField.cardinality(), vectorDataInput, byteSize);
+ int byteSize = fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize;
OnHeapHnswGraph graph = null;
- if (offHeapVectors.size() != 0) {
+ if (docsWithField.cardinality() != 0) {
// build graph
- HnswGraphBuilder<?> hnswGraphBuilder =
- HnswGraphBuilder.create(
- offHeapVectors,
- fieldInfo.getVectorEncoding(),
- fieldInfo.getVectorSimilarityFunction(),
- M,
- beamWidth,
- HnswGraphBuilder.randSeed);
- hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
- graph = hnswGraphBuilder.build(offHeapVectors.copy());
+ switch (fieldInfo.getVectorEncoding()) {
+ case BYTE:
+ OffHeapByteVectorValues.DenseOffHeapVectorValues byteVectorValues =
+ new OffHeapByteVectorValues.DenseOffHeapVectorValues(
+ fieldInfo.getVectorDimension(),
+ docsWithField.cardinality(),
+ vectorDataInput,
+ byteSize);
+ HnswGraphBuilder<BytesRef> bytesRefHnswGraphBuilder =
+ HnswGraphBuilder.create(
+ byteVectorValues,
+ fieldInfo.getVectorEncoding(),
+ fieldInfo.getVectorSimilarityFunction(),
+ M,
+ beamWidth,
+ HnswGraphBuilder.randSeed);
+ bytesRefHnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
+ graph = bytesRefHnswGraphBuilder.build(byteVectorValues.copy());
+ break;
+ case FLOAT32:
+ OffHeapVectorValues.DenseOffHeapVectorValues vectorValues =
+ new OffHeapVectorValues.DenseOffHeapVectorValues(
+ fieldInfo.getVectorDimension(),
+ docsWithField.cardinality(),
+ vectorDataInput,
+ byteSize);
+ HnswGraphBuilder<float[]> hnswGraphBuilder =
+ HnswGraphBuilder.create(
+ vectorValues,
+ fieldInfo.getVectorEncoding(),
+ fieldInfo.getVectorSimilarityFunction(),
+ M,
+ beamWidth,
+ HnswGraphBuilder.randSeed);
+ hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
+ graph = hnswGraphBuilder.build(vectorValues.copy());
+ break;
+ default:
+ throw new IllegalArgumentException(
+ "unknown vector encoding=" + fieldInfo.getVectorEncoding());
+ }
writeGraph(graph);
}
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
@@ -564,16 +606,37 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
}
}
+ /**
+ * Writes the byte vector values to the output and returns the set of documents that contain
+ * vectors.
+ *
+ * <p>Vectors are consumed by iterating {@code byteVectorValues} with {@code nextDoc()} and are
+ * written back-to-back as raw bytes, with no per-vector header or separator.
+ *
+ * @param output destination for the raw vector bytes
+ * @param byteVectorValues source of the vectors; it is fully consumed by this call
+ * @return the doc IDs for which a vector was written
+ */
+ private static DocsWithFieldSet writeByteVectorData(
+ IndexOutput output, ByteVectorValues byteVectorValues) throws IOException {
+ DocsWithFieldSet docsWithField = new DocsWithFieldSet();
+ for (int docV = byteVectorValues.nextDoc();
+ docV != NO_MORE_DOCS;
+ docV = byteVectorValues.nextDoc()) {
+ // write vector
+ BytesRef binaryValue = byteVectorValues.binaryValue();
+ // every byte vector must occupy exactly dimension() * BYTE.byteSize bytes
+ assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize;
+ output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
+ docsWithField.add(docV);
+ }
+ return docsWithField;
+ }
+
/**
* Writes the vector values to the output and returns a set of documents that contains vectors.
*/
private static DocsWithFieldSet writeVectorData(
- IndexOutput output, VectorValues vectors, int scalarSize) throws IOException {
+ IndexOutput output, VectorValues floatVectorValues) throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
- for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) {
+ for (int docV = floatVectorValues.nextDoc();
+ docV != NO_MORE_DOCS;
+ docV = floatVectorValues.nextDoc()) {
// write vector
- BytesRef binaryValue = vectors.binaryValue();
- assert binaryValue.length == vectors.dimension() * scalarSize;
+ BytesRef binaryValue = floatVectorValues.binaryValue();
+ assert binaryValue.length == floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize;
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
docsWithField.add(docV);
}
@@ -590,7 +653,6 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
private final int dim;
private final DocsWithFieldSet docsWithField;
private final List<T> vectors;
- private final RAVectorValues<T> raVectorValues;
private final HnswGraphBuilder<T> hnswGraphBuilder;
private int lastDocID = -1;
@@ -626,16 +688,15 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
this.dim = fieldInfo.getVectorDimension();
this.docsWithField = new DocsWithFieldSet();
vectors = new ArrayList<>();
- raVectorValues = new RAVectorValues<>(vectors, dim);
+ RAVectorValues<T> raVectorValues = new RAVectorValues<>(vectors, dim);
hnswGraphBuilder =
- (HnswGraphBuilder<T>)
- HnswGraphBuilder.create(
- raVectorValues,
- fieldInfo.getVectorEncoding(),
- fieldInfo.getVectorSimilarityFunction(),
- M,
- beamWidth,
- HnswGraphBuilder.randSeed);
+ HnswGraphBuilder.create(
+ raVectorValues,
+ fieldInfo.getVectorEncoding(),
+ fieldInfo.getVectorSimilarityFunction(),
+ M,
+ beamWidth,
+ HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(infoStream);
}
@@ -680,7 +741,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
}
}
- private static class RAVectorValues<T> implements RandomAccessVectorValues {
+ private static class RAVectorValues<T> implements RandomAccessVectorValues<T> {
private final List<T> vectors;
private final int dim;
@@ -700,17 +761,12 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
}
@Override
- public float[] vectorValue(int targetOrd) throws IOException {
- return (float[]) vectors.get(targetOrd);
- }
-
- @Override
- public BytesRef binaryValue(int targetOrd) throws IOException {
- return (BytesRef) vectors.get(targetOrd);
+ public T vectorValue(int targetOrd) throws IOException {
+ return vectors.get(targetOrd);
}
@Override
- public RandomAccessVectorValues copy() throws IOException {
+ public RAVectorValues<T> copy() throws IOException {
return this;
}
}
diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java
index 4994cf692d4..b575ab8fdbb 100644
--- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java
+++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java
@@ -25,6 +25,7 @@ import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
@@ -144,6 +145,39 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
return new SimpleTextVectorValues(fieldEntry, bytesSlice, info.getVectorEncoding());
}
+ /**
+ * Returns the byte vectors of {@code field}; all vectors are parsed eagerly when the returned
+ * values object is constructed.
+ *
+ * <p>Returns {@code null} when the field is unknown or when this segment has no vector entry
+ * for it. Throws {@link IllegalStateException} when the field has no vector dimension or when
+ * its dimension disagrees with the stored entry.
+ *
+ * <p>NOTE(review): the field's vector encoding is not checked here. If this is called on a
+ * FLOAT32 field, the stored values would be parsed as floats and narrowed to bytes by
+ * SimpleTextByteVectorValues#readVector. Confirm that callers only use it for BYTE fields.
+ */
+ @Override
+ public ByteVectorValues getByteVectorValues(String field) throws IOException {
+ FieldInfo info = readState.fieldInfos.fieldInfo(field);
+ if (info == null) {
+ // mirror the handling in Lucene90VectorReader#getVectorValues
+ // needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs
+ return null;
+ }
+ int dimension = info.getVectorDimension();
+ if (dimension == 0) {
+ throw new IllegalStateException(
+ "KNN vectors readers should not be called on fields that don't enable KNN vectors");
+ }
+ FieldEntry fieldEntry = fieldEntries.get(field);
+ if (fieldEntry == null) {
+ // mirror the handling in Lucene90VectorReader#getVectorValues
+ // needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs
+ return null;
+ }
+ if (dimension != fieldEntry.dimension) {
+ throw new IllegalStateException(
+ "Inconsistent vector dimension for field=\""
+ + field
+ + "\"; "
+ + dimension
+ + " != "
+ + fieldEntry.dimension);
+ }
+ // slice covering only this field's vector data
+ IndexInput bytesSlice =
+ dataIn.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
+ return new SimpleTextByteVectorValues(fieldEntry, bytesSlice);
+ }
+
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
@@ -188,7 +222,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
@Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
- VectorValues values = getVectorValues(field);
+ ByteVectorValues values = getByteVectorValues(field);
if (target.length != values.dimension()) {
throw new IllegalArgumentException(
"vector query dimension: "
@@ -214,7 +248,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
break;
}
- BytesRef vector = values.binaryValue();
+ BytesRef vector = values.vectorValue();
float score = vectorSimilarity.compare(vector, target);
topK.insertWithOverflow(new ScoreDoc(doc, score));
numVisited++;
@@ -302,7 +336,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
}
private static class SimpleTextVectorValues extends VectorValues
- implements RandomAccessVectorValues {
+ implements RandomAccessVectorValues<float[]> {
private final BytesRefBuilder scratch = new BytesRefBuilder();
private final FieldEntry entry;
@@ -357,7 +391,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
}
@Override
- public RandomAccessVectorValues copy() {
+ public RandomAccessVectorValues<float[]> copy() {
return this;
}
@@ -410,10 +444,99 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
public float[] vectorValue(int targetOrd) throws IOException {
return values[targetOrd];
}
+ }
+
+ /**
+ * Byte vectors of one field, parsed eagerly from the SimpleText vector data when constructed.
+ * Supports forward iteration (docID/nextDoc/advance) and random access by ordinal via
+ * {@link #vectorValue(int)}.
+ */
+ private static class SimpleTextByteVectorValues extends ByteVectorValues
+ implements RandomAccessVectorValues<BytesRef> {
+
+ private final BytesRefBuilder scratch = new BytesRefBuilder();
+ private final FieldEntry entry;
+ private final IndexInput in;
+ // single reusable BytesRef; its bytes are re-pointed at values[ord] on every access
+ private final BytesRef binaryValue;
+ private final byte[][] values;
+
+ // ordinal of the iterator's current vector; -1 before the first nextDoc()
+ int curOrd;
+
+ SimpleTextByteVectorValues(FieldEntry entry, IndexInput in) throws IOException {
+ this.entry = entry;
+ this.in = in;
+ values = new byte[entry.size()][entry.dimension];
+ binaryValue = new BytesRef(entry.dimension);
+ binaryValue.length = binaryValue.bytes.length;
+ curOrd = -1;
+ readAllVectors();
+ }
@Override
- public BytesRef binaryValue(int targetOrd) throws IOException {
- throw new UnsupportedOperationException();
+ public int dimension() {
+ return entry.dimension;
+ }
+
+ @Override
+ public int size() {
+ return entry.size();
+ }
+
+ @Override
+ public BytesRef vectorValue() {
+ // vector at the iterator position
+ binaryValue.bytes = values[curOrd];
+ return binaryValue;
+ }
+
+ @Override
+ public RandomAccessVectorValues<BytesRef> copy() {
+ // NOTE(review): returns this, so the "copy" shares the reusable BytesRef with the original
+ return this;
+ }
+
+ @Override
+ public int docID() {
+ if (curOrd == -1) {
+ return -1;
+ } else if (curOrd >= entry.size()) {
+ // when call to advance / nextDoc below already returns NO_MORE_DOCS, calling docID
+ // immediately afterward should also return NO_MORE_DOCS
+ // this is needed for TestSimpleTextKnnVectorsFormat.testAdvance test case
+ return NO_MORE_DOCS;
+ }
+
+ return entry.ordToDoc[curOrd];
+ }
+
+ @Override
+ public int nextDoc() throws IOException {
+ if (++curOrd < entry.size()) {
+ return docID();
+ }
+ return NO_MORE_DOCS;
+ }
+
+ @Override
+ public int advance(int target) throws IOException {
+ return slowAdvance(target);
+ }
+
+ // parses one text line per vector, in ordinal order
+ private void readAllVectors() throws IOException {
+ for (byte[] value : values) {
+ readVector(value);
+ }
+ }
+
+ private void readVector(byte[] value) throws IOException {
+ SimpleTextUtil.readLine(in, scratch);
+ // skip leading "[" and strip trailing "]"
+ String s = new BytesRef(scratch.bytes(), 1, scratch.length() - 2).utf8ToString();
+ String[] floatStrings = s.split(",");
+ assert floatStrings.length == value.length
+ : " read " + s + " when expecting " + value.length + " floats";
+ for (int i = 0; i < floatStrings.length; i++) {
+ // components are stored as decimal text and narrowed to byte
+ value[i] = (byte) Float.parseFloat(floatStrings[i]);
+ }
+ }
+
+ @Override
+ public BytesRef vectorValue(int targetOrd) throws IOException {
+ // random access must honor the requested ordinal, not the iterator position (curOrd);
+ // using curOrd returned the wrong vector and failed with curOrd == -1 before iteration
+ binaryValue.bytes = values[targetOrd];
+ return binaryValue;
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java
index 6010918fb57..760f959b2fd 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java
@@ -22,6 +22,7 @@ import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
+import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;
@@ -85,6 +86,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
: vectorValues;
}
+ @Override
+ public ByteVectorValues getByteVectorValues(String field) throws IOException {
+ // Byte vectors are not available from this buffered view; this writer's field writer only
+ // accepts float[] values. NOTE(review): confirm no caller reaches this for BYTE fields.
+ throw new UnsupportedOperationException();
+ }
+
@Override
public TopDocs search(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
@@ -202,6 +208,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
}
+ @Override
+ public ByteVectorValues getByteVectorValues(String field) {
+ // Only float vectors are merged through this reader (see MergedVectorValues.mergeVectorValues
+ // above); byte vectors are unsupported.
+ throw new UnsupportedOperationException();
+ }
+
@Override
public void checkIntegrity() {}
};
@@ -228,7 +239,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
}
@Override
- public void addValue(int docID, Object value) {
+ public void addValue(int docID, float[] value) {
if (docID == lastDocID) {
throw new IllegalArgumentException(
"VectorValuesField \""
@@ -236,31 +247,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
+ "\" appears more than once in this document (only one value is allowed per field)");
}
assert docID > lastDocID;
- float[] vectorValue;
- switch (fieldInfo.getVectorEncoding()) {
- case BYTE:
- vectorValue = bytesToFloats((BytesRef) value);
- break;
- default:
- case FLOAT32:
- vectorValue = (float[]) value;
- break;
- }
- ;
docsWithField.add(docID);
- vectors.add(copyValue(vectorValue));
+ vectors.add(copyValue(value));
lastDocID = docID;
}
- private float[] bytesToFloats(BytesRef b) {
- // This is used only by SimpleTextKnnVectorsWriter
- float[] floats = new float[dim];
- for (int i = 0; i < dim; i++) {
- floats[i] = b.bytes[i + b.offset];
- }
- return floats;
- }
-
@Override
public float[] copyValue(float[] vectorValue) {
return ArrayUtil.copyOfSubArray(vectorValue, 0, dim);
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnFieldVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnFieldVectorsWriter.java
index d64f43e40d2..3e5b0fe1c6b 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/KnnFieldVectorsWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnFieldVectorsWriter.java
@@ -34,7 +34,7 @@ public abstract class KnnFieldVectorsWriter<T> implements Accountable {
* Add new docID with its vector value to the given field for indexing. Doc IDs must be added in
* increasing order.
*/
- public abstract void addValue(int docID, Object vectorValue) throws IOException;
+ public abstract void addValue(int docID, T vectorValue) throws IOException;
/**
* Used to copy values being indexed to internal storage.
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java
index 96a79cde891..2eca999f540 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java
@@ -18,6 +18,7 @@
package org.apache.lucene.codecs;
import java.io.IOException;
+import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorValues;
@@ -98,6 +99,11 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
throw new UnsupportedOperationException();
}
+ @Override
+ public ByteVectorValues getByteVectorValues(String field) {
+ // Unsupported, like the float-vector accessor of this reader.
+ // NOTE(review): presumably this reader backs a format that stores no vectors; confirm.
+ throw new UnsupportedOperationException();
+ }
+
@Override
public TopDocs search(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java
index 6741674e01d..2199067c829 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java
@@ -19,6 +19,7 @@ package org.apache.lucene.codecs;
import java.io.Closeable;
import java.io.IOException;
+import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.ScoreDoc;
@@ -51,6 +52,13 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
*/
public abstract VectorValues getVectorValues(String field) throws IOException;
+ /**
+ * Returns the {@link ByteVectorValues} for the given {@code field}. The behavior is undefined if
+ * the given field doesn't have KNN vectors enabled on its {@link FieldInfo}.
+ *
+ * <p>Callers should not assume a non-null result: some implementations return {@code null} when
+ * the segment holds no vectors for the field, and merging code null-checks the result.
+ *
+ * @param field name of the field to read (intended for fields with BYTE vector encoding)
+ * @return the byte vectors of {@code field}, possibly {@code null}
+ * @throws IOException if reading the vectors fails
+ */
+ public abstract ByteVectorValues getByteVectorValues(String field) throws IOException;
+
/**
* Return the k nearest neighbor documents as determined by comparison of their vector values for
* this field, to the given vector, by the field's similarity function. The score of each document
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java
index e45057ef37d..64aa8edfebe 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java
@@ -21,10 +21,12 @@ import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
+import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocIDMerger;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.Sorter;
+import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.Accountable;
@@ -44,13 +46,29 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
/** Write field for merging */
@SuppressWarnings("unchecked")
- public <T> void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
- KnnFieldVectorsWriter<T> writer = (KnnFieldVectorsWriter<T>) addField(fieldInfo);
- VectorValues mergedValues = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
- for (int doc = mergedValues.nextDoc();
- doc != DocIdSetIterator.NO_MORE_DOCS;
- doc = mergedValues.nextDoc()) {
- writer.addValue(doc, mergedValues.vectorValue());
+ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
+ // Dispatch on the encoding so each branch uses a writer typed to that encoding's vector type.
+ // The unchecked casts assume addField returns a writer matching the field's encoding.
+ switch (fieldInfo.getVectorEncoding()) {
+ case BYTE:
+ KnnFieldVectorsWriter<BytesRef> byteWriter =
+ (KnnFieldVectorsWriter<BytesRef>) addField(fieldInfo);
+ ByteVectorValues mergedBytes =
+ MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
+ for (int doc = mergedBytes.nextDoc();
+ doc != DocIdSetIterator.NO_MORE_DOCS;
+ doc = mergedBytes.nextDoc()) {
+ byteWriter.addValue(doc, mergedBytes.vectorValue());
+ }
+ break;
+ case FLOAT32:
+ KnnFieldVectorsWriter<float[]> floatWriter =
+ (KnnFieldVectorsWriter<float[]>) addField(fieldInfo);
+ VectorValues mergedFloats = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
+ for (int doc = mergedFloats.nextDoc();
+ doc != DocIdSetIterator.NO_MORE_DOCS;
+ doc = mergedFloats.nextDoc()) {
+ floatWriter.addValue(doc, mergedFloats.vectorValue());
+ }
+ break;
+ default:
+ // fail loudly instead of silently merging nothing for an unhandled encoding
+ throw new IllegalArgumentException(
+ "unknown vector encoding=" + fieldInfo.getVectorEncoding());
}
}
@@ -104,20 +122,34 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
}
}
- /** View over multiple VectorValues supporting iterator-style access via DocIdMerger. */
- protected static class MergedVectorValues extends VectorValues {
- private final List<VectorValuesSub> subs;
- private final DocIDMerger<VectorValuesSub> docIdMerger;
- private final int size;
+ /**
+ * Adapts one segment's {@link ByteVectorValues} to a {@link DocIDMerger.Sub}, so byte vectors
+ * from several segments can be merged through a {@link DocIDMerger}.
+ */
+ private static class ByteVectorValuesSub extends DocIDMerger.Sub {
+
+ /** The wrapped per-segment values, iterated by {@link #nextDoc()}. */
+ final ByteVectorValues values;
+
+ ByteVectorValuesSub(MergeState.DocMap docMap, ByteVectorValues values) {
+ super(docMap);
+ this.values = values;
+ // the wrapped iterator must not have been advanced yet
+ assert values.docID() == -1;
+ }
- private int docId;
- private VectorValuesSub current;
+ @Override
+ public int nextDoc() throws IOException {
+ return values.nextDoc();
+ }
+ }
+
+ /** Static factories for merged views over the vector values of all segments being merged. */
+ protected static final class MergedVectorValues {
+ // Not instantiable: this class only hosts static merge factories and their view classes.
+ private MergedVectorValues() {}
/** Returns a merged view over all the segment's {@link VectorValues}. */
- public static MergedVectorValues mergeVectorValues(FieldInfo fieldInfo, MergeState mergeState)
+ public static VectorValues mergeVectorValues(FieldInfo fieldInfo, MergeState mergeState)
throws IOException {
assert fieldInfo != null && fieldInfo.hasVectorValues();
-
+ if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
+ throw new UnsupportedOperationException(
+ "Cannot merge vectors encoded as [" + fieldInfo.getVectorEncoding() + "] as FLOAT32");
+ }
List<VectorValuesSub> subs = new ArrayList<>();
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
@@ -128,60 +160,147 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
}
}
}
- return new MergedVectorValues(subs, mergeState);
+ return new MergedFloat32VectorValues(subs, mergeState);
}
- private MergedVectorValues(List<VectorValuesSub> subs, MergeState mergeState)
+ /**
+ * Returns a merged view over all the segments' {@link ByteVectorValues}.
+ *
+ * <p>Segments whose reader is {@code null}, or whose reader returns {@code null} for this field,
+ * are skipped.
+ *
+ * @throws UnsupportedOperationException if the field's vector encoding is not BYTE
+ */
+ public static ByteVectorValues mergeByteVectorValues(FieldInfo fieldInfo, MergeState mergeState)
throws IOException {
- this.subs = subs;
- docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
- int totalSize = 0;
- for (VectorValuesSub sub : subs) {
- totalSize += sub.values.size();
- }
- size = totalSize;
- docId = -1;
+ assert fieldInfo != null && fieldInfo.hasVectorValues();
+ if (fieldInfo.getVectorEncoding() != VectorEncoding.BYTE) {
+ throw new UnsupportedOperationException(
+ "Cannot merge vectors encoded as [" + fieldInfo.getVectorEncoding() + "] as BYTE");
+ }
+ List<ByteVectorValuesSub> subs = new ArrayList<>();
+ for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
+ KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
+ if (knnVectorsReader != null) {
+ ByteVectorValues values = knnVectorsReader.getByteVectorValues(fieldInfo.name);
+ // readers may report no values for this field in their segment
+ if (values != null) {
+ subs.add(new ByteVectorValuesSub(mergeState.docMaps[i], values));
+ }
+ }
+ }
+ return new MergedByteVectorValues(subs, mergeState);
}
- @Override
- public int docID() {
- return docId;
- }
+ /**
+ * Merged FLOAT32 vectors of several segments, iterated forward via {@link DocIDMerger}; each
+ * returned doc ID is the producing sub's mapped doc ID. {@link #advance(int)} is not supported.
+ */
+ static class MergedFloat32VectorValues extends VectorValues {
+ private final List<VectorValuesSub> subs;
+ private final DocIDMerger<VectorValuesSub> docIdMerger;
+ private final int size;
- @Override
- public int nextDoc() throws IOException {
- current = docIdMerger.next();
- if (current == null) {
- docId = NO_MORE_DOCS;
- } else {
- docId = current.mappedDocID;
+ private int docId;
+ // sub positioned on the current doc; null before the first nextDoc() and once exhausted
+ VectorValuesSub current;
+
+ private MergedFloat32VectorValues(List<VectorValuesSub> subs, MergeState mergeState)
+ throws IOException {
+ this.subs = subs;
+ docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
+ // total number of vectors across all contributing segments
+ int totalSize = 0;
+ for (VectorValuesSub sub : subs) {
+ totalSize += sub.values.size();
+ }
+ size = totalSize;
+ docId = -1;
}
- return docId;
- }
- @Override
- public float[] vectorValue() throws IOException {
- return current.values.vectorValue();
- }
+ @Override
+ public int docID() {
+ return docId;
+ }
- @Override
- public BytesRef binaryValue() throws IOException {
- return current.values.binaryValue();
- }
+ @Override
+ public int nextDoc() throws IOException {
+ current = docIdMerger.next();
+ if (current == null) {
+ docId = NO_MORE_DOCS;
+ } else {
+ docId = current.mappedDocID;
+ }
+ return docId;
+ }
- @Override
- public int advance(int target) {
- throw new UnsupportedOperationException();
- }
+ @Override
+ public float[] vectorValue() throws IOException {
+ return current.values.vectorValue();
+ }
- @Override
- public int size() {
- return size;
+ @Override
+ public BytesRef binaryValue() throws IOException {
+ return current.values.binaryValue();
+ }
+
+ @Override
+ public int advance(int target) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int size() {
+ return size;
+ }
+
+ @Override
+ public int dimension() {
+ // NOTE(review): throws IndexOutOfBoundsException if no segment contributed values
+ return subs.get(0).values.dimension();
+ }
}
- @Override
- public int dimension() {
- return subs.get(0).values.dimension();
+ /**
+ * Merged BYTE vectors of several segments, iterated forward via {@link DocIDMerger}; each
+ * returned doc ID is the producing sub's mapped doc ID. {@link #advance(int)} is not supported.
+ */
+ static class MergedByteVectorValues extends ByteVectorValues {
+ private final List<ByteVectorValuesSub> subs;
+ private final DocIDMerger<ByteVectorValuesSub> docIdMerger;
+ private final int size;
+
+ private int docId;
+ // sub positioned on the current doc; null before the first nextDoc() and once exhausted
+ ByteVectorValuesSub current;
+
+ private MergedByteVectorValues(List<ByteVectorValuesSub> subs, MergeState mergeState)
+ throws IOException {
+ this.subs = subs;
+ docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
+ // total number of vectors across all contributing segments
+ int totalSize = 0;
+ for (ByteVectorValuesSub sub : subs) {
+ totalSize += sub.values.size();
+ }
+ size = totalSize;
+ docId = -1;
+ }
+
+ @Override
+ public BytesRef vectorValue() throws IOException {
+ return current.values.vectorValue();
+ }
+
+ @Override
+ public int docID() {
+ return docId;
+ }
+
+ @Override
+ public int nextDoc() throws IOException {
+ current = docIdMerger.next();
+ if (current == null) {
+ docId = NO_MORE_DOCS;
+ } else {
+ docId = current.mappedDocID;
+ }
+ return docId;
+ }
+
+ @Override
+ public int advance(int target) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int size() {
+ return size;
+ }
+
+ @Override
+ public int dimension() {
+ // NOTE(review): throws IndexOutOfBoundsException if no segment contributed values
+ return subs.get(0).values.dimension();
+ }
}
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/ExpandingVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/ExpandingVectorValues.java
deleted file mode 100644
index 3f62cdb7cb6..00000000000
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/ExpandingVectorValues.java
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.lucene.codecs.lucene95;
-
-import java.io.IOException;
-import org.apache.lucene.index.FilterVectorValues;
-import org.apache.lucene.index.VectorValues;
-import org.apache.lucene.util.BytesRef;
-
-/** reads from byte-encoded data */
-class ExpandingVectorValues extends FilterVectorValues {
-
- private final float[] value;
-
- /**
- * @param in the wrapped values
- */
- protected ExpandingVectorValues(VectorValues in) {
- super(in);
- value = new float[in.dimension()];
- }
-
- @Override
- public float[] vectorValue() throws IOException {
- BytesRef binaryValue = binaryValue();
- byte[] bytes = binaryValue.bytes;
- for (int i = 0, j = binaryValue.offset; i < value.length; i++, j++) {
- value[i] = bytes[j];
- }
- return value;
- }
-}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java
index e22d01cb06c..e278d50c581 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java
@@ -25,6 +25,7 @@ import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
@@ -246,12 +247,31 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader {
@Override
public VectorValues getVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
- VectorValues values = OffHeapVectorValues.load(fieldEntry, vectorData);
- if (fieldEntry.vectorEncoding == VectorEncoding.BYTE) {
- return new ExpandingVectorValues(values);
- } else {
- return values;
+ if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
+ throw new IllegalArgumentException(
+ "field=\""
+ + field
+ + "\" is encoded as: "
+ + fieldEntry.vectorEncoding
+ + " expected: "
+ + VectorEncoding.FLOAT32);
}
+ return OffHeapVectorValues.load(fieldEntry, vectorData);
+ }
+
+ @Override
+ public ByteVectorValues getByteVectorValues(String field) throws IOException {
+ FieldEntry fieldEntry = fields.get(field);
+ if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
+ throw new IllegalArgumentException(
+ "field=\""
+ + field
+ + "\" is encoded as: "
+ + fieldEntry.vectorEncoding
+ + " expected: "
+ + VectorEncoding.BYTE);
+ }
+ return OffHeapByteVectorValues.load(fieldEntry, vectorData);
}
@Override
@@ -311,7 +331,7 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader {
// bound k by total number of vectors to prevent oversizing data structures
k = Math.min(k, fieldEntry.size());
- OffHeapVectorValues vectorValues = OffHeapVectorValues.load(fieldEntry, vectorData);
+ OffHeapByteVectorValues vectorValues = OffHeapByteVectorValues.load(fieldEntry, vectorData);
NeighborQueue results =
HnswGraphSearcher.search(
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java
index 12e0f0bd73e..1963a199cfd 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java
@@ -403,8 +403,6 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
@Override
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
- VectorValues vectors = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
-
IndexOutput tempVectorData =
segmentWriteState.directory.createTempOutput(
vectorData.getName(), "temp", segmentWriteState.context);
@@ -412,8 +410,24 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
boolean success = false;
try {
// write the vector data to a temporary file
- DocsWithFieldSet docsWithField =
- writeVectorData(tempVectorData, vectors, fieldInfo.getVectorEncoding().byteSize);
+ // byte and float vectors are merged through different value APIs, so dispatch on encoding
+ final DocsWithFieldSet docsWithField;
+ switch (fieldInfo.getVectorEncoding()) {
+ case BYTE:
+ docsWithField =
+ writeByteVectorData(
+ tempVectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
+ break;
+ case FLOAT32:
+ docsWithField =
+ writeVectorData(
+ tempVectorData, MergedVectorValues.mergeVectorValues(fieldInfo, mergeState));
+ break;
+ default:
+ throw new IllegalArgumentException(
+ "unknown vector encoding=" + fieldInfo.getVectorEncoding());
+ }
+
CodecUtil.writeFooter(tempVectorData);
IOUtils.close(tempVectorData);
@@ -429,24 +443,52 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
// we use Lucene95HnswVectorsReader.DenseOffHeapVectorValues for the graph construction
// doesn't need to know docIds
// TODO: separate random access vector values from DocIdSetIterator?
- int byteSize = vectors.dimension() * fieldInfo.getVectorEncoding().byteSize;
- OffHeapVectorValues offHeapVectors =
- new OffHeapVectorValues.DenseOffHeapVectorValues(
- vectors.dimension(), docsWithField.cardinality(), vectorDataInput, byteSize);
+ int byteSize = fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize;
OnHeapHnswGraph graph = null;
int[][] vectorIndexNodeOffsets = null;
- if (offHeapVectors.size() != 0) {
+ if (docsWithField.cardinality() != 0) {
// build graph
- HnswGraphBuilder<?> hnswGraphBuilder =
- HnswGraphBuilder.create(
- offHeapVectors,
- fieldInfo.getVectorEncoding(),
- fieldInfo.getVectorSimilarityFunction(),
- M,
- beamWidth,
- HnswGraphBuilder.randSeed);
- hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
- graph = hnswGraphBuilder.build(offHeapVectors.copy());
+ switch (fieldInfo.getVectorEncoding()) {
+ case BYTE:
+ OffHeapByteVectorValues.DenseOffHeapVectorValues byteVectorValues =
+ new OffHeapByteVectorValues.DenseOffHeapVectorValues(
+ fieldInfo.getVectorDimension(),
+ docsWithField.cardinality(),
+ vectorDataInput,
+ byteSize);
+ HnswGraphBuilder<BytesRef> bytesRefHnswGraphBuilder =
+ HnswGraphBuilder.create(
+ byteVectorValues,
+ fieldInfo.getVectorEncoding(),
+ fieldInfo.getVectorSimilarityFunction(),
+ M,
+ beamWidth,
+ HnswGraphBuilder.randSeed);
+ bytesRefHnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
+ graph = bytesRefHnswGraphBuilder.build(byteVectorValues.copy());
+ break;
+ case FLOAT32:
+ OffHeapVectorValues.DenseOffHeapVectorValues vectorValues =
+ new OffHeapVectorValues.DenseOffHeapVectorValues(
+ fieldInfo.getVectorDimension(),
+ docsWithField.cardinality(),
+ vectorDataInput,
+ byteSize);
+ HnswGraphBuilder<float[]> hnswGraphBuilder =
+ HnswGraphBuilder.create(
+ vectorValues,
+ fieldInfo.getVectorEncoding(),
+ fieldInfo.getVectorSimilarityFunction(),
+ M,
+ beamWidth,
+ HnswGraphBuilder.randSeed);
+ hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
+ graph = hnswGraphBuilder.build(vectorValues.copy());
+ break;
+ default:
+ throw new IllegalArgumentException(
+ "unknown vector encoding=" + fieldInfo.getVectorEncoding());
+ }
vectorIndexNodeOffsets = writeGraph(graph);
}
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
@@ -617,16 +659,37 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
}
}
+ /**
+ * Writes the byte vector values to the output and returns a set of documents that contains
+ * vectors.
+ */
+ private static DocsWithFieldSet writeByteVectorData(
+ IndexOutput output, ByteVectorValues byteVectorValues) throws IOException {
+ DocsWithFieldSet docsWithField = new DocsWithFieldSet();
+ for (int docV = byteVectorValues.nextDoc();
+ docV != NO_MORE_DOCS;
+ docV = byteVectorValues.nextDoc()) {
+ // write vector
+ BytesRef binaryValue = byteVectorValues.binaryValue();
+ assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize;
+ output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
+ docsWithField.add(docV);
+ }
+ return docsWithField;
+ }
+
/**
* Writes the vector values to the output and returns a set of documents that contains vectors.
*/
private static DocsWithFieldSet writeVectorData(
- IndexOutput output, VectorValues vectors, int scalarSize) throws IOException {
+ IndexOutput output, VectorValues floatVectorValues) throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
- for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) {
+ for (int docV = floatVectorValues.nextDoc();
+ docV != NO_MORE_DOCS;
+ docV = floatVectorValues.nextDoc()) {
// write vector
- BytesRef binaryValue = vectors.binaryValue();
- assert binaryValue.length == vectors.dimension() * scalarSize;
+ BytesRef binaryValue = floatVectorValues.binaryValue();
+ assert binaryValue.length == floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize;
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
docsWithField.add(docV);
}
@@ -643,7 +706,6 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
private final int dim;
private final DocsWithFieldSet docsWithField;
private final List<T> vectors;
- private final RAVectorValues<T> raVectorValues;
private final HnswGraphBuilder<T> hnswGraphBuilder;
private int lastDocID = -1;
@@ -673,36 +735,31 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
}
}
- @SuppressWarnings("unchecked")
FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
throws IOException {
this.fieldInfo = fieldInfo;
this.dim = fieldInfo.getVectorDimension();
this.docsWithField = new DocsWithFieldSet();
vectors = new ArrayList<>();
- raVectorValues = new RAVectorValues<>(vectors, dim);
hnswGraphBuilder =
- (HnswGraphBuilder<T>)
- HnswGraphBuilder.create(
- raVectorValues,
- fieldInfo.getVectorEncoding(),
- fieldInfo.getVectorSimilarityFunction(),
- M,
- beamWidth,
- HnswGraphBuilder.randSeed);
+ HnswGraphBuilder.create(
+ new RAVectorValues<>(vectors, dim),
+ fieldInfo.getVectorEncoding(),
+ fieldInfo.getVectorSimilarityFunction(),
+ M,
+ beamWidth,
+ HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(infoStream);
}
@Override
- @SuppressWarnings("unchecked")
- public void addValue(int docID, Object value) throws IOException {
+ public void addValue(int docID, T vectorValue) throws IOException {
if (docID == lastDocID) {
throw new IllegalArgumentException(
"VectorValuesField \""
+ fieldInfo.name
+ "\" appears more than once in this document (only one value is allowed per field)");
}
- T vectorValue = (T) value;
assert docID > lastDocID;
docsWithField.add(docID);
vectors.add(copyValue(vectorValue));
@@ -735,7 +792,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
}
}
- private static class RAVectorValues<T> implements RandomAccessVectorValues {
+ private static class RAVectorValues<T> implements RandomAccessVectorValues<T> {
private final List<T> vectors;
private final int dim;
@@ -755,17 +812,12 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
}
@Override
- public float[] vectorValue(int targetOrd) throws IOException {
- return (float[]) vectors.get(targetOrd);
- }
-
- @Override
- public BytesRef binaryValue(int targetOrd) throws IOException {
- return (BytesRef) vectors.get(targetOrd);
+ public T vectorValue(int targetOrd) throws IOException {
+ return vectors.get(targetOrd);
}
@Override
- public RandomAccessVectorValues copy() throws IOException {
+ public RandomAccessVectorValues<T> copy() throws IOException {
return this;
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java
similarity index 76%
copy from lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapVectorValues.java
copy to lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java
index 124430bceb4..05c7e2204a7 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapVectorValues.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java
@@ -20,7 +20,8 @@ package org.apache.lucene.codecs.lucene95;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
-import org.apache.lucene.index.VectorValues;
+import org.apache.lucene.index.ByteVectorValues;
+import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
@@ -29,7 +30,8 @@ import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
-abstract class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
+abstract class OffHeapByteVectorValues extends ByteVectorValues
+ implements RandomAccessVectorValues<BytesRef> {
protected final int dimension;
protected final int size;
@@ -37,15 +39,13 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
protected final BytesRef binaryValue;
protected final ByteBuffer byteBuffer;
protected final int byteSize;
- protected final float[] value;
- OffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
+ OffHeapByteVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
this.dimension = dimension;
this.size = size;
this.slice = slice;
this.byteSize = byteSize;
byteBuffer = ByteBuffer.allocate(byteSize);
- value = new float[dimension];
binaryValue = new BytesRef(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
}
@@ -60,14 +60,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public float[] vectorValue(int targetOrd) throws IOException {
- slice.seek((long) targetOrd * byteSize);
- slice.readFloats(value, 0, value.length);
- return value;
- }
-
- @Override
- public BytesRef binaryValue(int targetOrd) throws IOException {
+ public BytesRef vectorValue(int targetOrd) throws IOException {
readValue(targetOrd);
return binaryValue;
}
@@ -79,24 +72,14 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
public abstract int ordToDoc(int ord);
- static OffHeapVectorValues load(
+ static OffHeapByteVectorValues load(
Lucene95HnswVectorsReader.FieldEntry fieldEntry, IndexInput vectorData) throws IOException {
- if (fieldEntry.docsWithFieldOffset == -2) {
+ if (fieldEntry.docsWithFieldOffset == -2 || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
return new EmptyOffHeapVectorValues(fieldEntry.dimension);
}
IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
- int byteSize;
- switch (fieldEntry.vectorEncoding) {
- case BYTE:
- byteSize = fieldEntry.dimension;
- break;
- case FLOAT32:
- byteSize = fieldEntry.dimension * Float.BYTES;
- break;
- default:
- throw new AssertionError();
- }
+ int byteSize = fieldEntry.dimension;
if (fieldEntry.docsWithFieldOffset == -1) {
return new DenseOffHeapVectorValues(
fieldEntry.dimension, fieldEntry.size, bytesSlice, byteSize);
@@ -107,7 +90,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
abstract Bits getAcceptOrds(Bits acceptDocs);
- static class DenseOffHeapVectorValues extends OffHeapVectorValues {
+ static class DenseOffHeapVectorValues extends OffHeapByteVectorValues {
private int doc = -1;
@@ -116,16 +99,9 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public float[] vectorValue() throws IOException {
+ public BytesRef vectorValue() throws IOException {
slice.seek((long) doc * byteSize);
- slice.readFloats(value, 0, value.length);
- return value;
- }
-
- @Override
- public BytesRef binaryValue() throws IOException {
- slice.seek((long) doc * byteSize);
- slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize, false);
+ slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
return binaryValue;
}
@@ -149,7 +125,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public RandomAccessVectorValues copy() throws IOException {
+ public RandomAccessVectorValues<BytesRef> copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
}
@@ -164,7 +140,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
}
- private static class SparseOffHeapVectorValues extends OffHeapVectorValues {
+ private static class SparseOffHeapVectorValues extends OffHeapByteVectorValues {
private final DirectMonotonicReader ordToDoc;
private final IndexedDISI disi;
// dataIn was used to init a new IndexedDIS for #randomAccess()
@@ -195,14 +171,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public float[] vectorValue() throws IOException {
- slice.seek((long) (disi.index()) * byteSize);
- slice.readFloats(value, 0, value.length);
- return value;
- }
-
- @Override
- public BytesRef binaryValue() throws IOException {
+ public BytesRef vectorValue() throws IOException {
slice.seek((long) (disi.index()) * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize, false);
return binaryValue;
@@ -225,7 +194,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public RandomAccessVectorValues copy() throws IOException {
+ public RandomAccessVectorValues<BytesRef> copy() throws IOException {
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
}
@@ -253,7 +222,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
}
- private static class EmptyOffHeapVectorValues extends OffHeapVectorValues {
+ private static class EmptyOffHeapVectorValues extends OffHeapByteVectorValues {
public EmptyOffHeapVectorValues(int dimension) {
super(dimension, 0, null, 0);
@@ -272,12 +241,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public float[] vectorValue() throws IOException {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public BytesRef binaryValue() throws IOException {
+ public BytesRef vectorValue() throws IOException {
throw new UnsupportedOperationException();
}
@@ -297,17 +261,12 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public RandomAccessVectorValues copy() throws IOException {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public float[] vectorValue(int targetOrd) throws IOException {
+ public RandomAccessVectorValues<BytesRef> copy() throws IOException {
throw new UnsupportedOperationException();
}
@Override
- public BytesRef binaryValue(int targetOrd) throws IOException {
+ public BytesRef vectorValue(int targetOrd) throws IOException {
throw new UnsupportedOperationException();
}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapVectorValues.java
index 124430bceb4..fdee80b186e 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapVectorValues.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapVectorValues.java
@@ -20,6 +20,7 @@ package org.apache.lucene.codecs.lucene95;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
+import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
@@ -29,7 +30,8 @@ import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
-abstract class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
+abstract class OffHeapVectorValues extends VectorValues
+ implements RandomAccessVectorValues<float[]> {
protected final int dimension;
protected final int size;
@@ -66,37 +68,17 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
return value;
}
- @Override
- public BytesRef binaryValue(int targetOrd) throws IOException {
- readValue(targetOrd);
- return binaryValue;
- }
-
- private void readValue(int targetOrd) throws IOException {
- slice.seek((long) targetOrd * byteSize);
- slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
- }
-
public abstract int ordToDoc(int ord);
static OffHeapVectorValues load(
Lucene95HnswVectorsReader.FieldEntry fieldEntry, IndexInput vectorData) throws IOException {
- if (fieldEntry.docsWithFieldOffset == -2) {
+ if (fieldEntry.docsWithFieldOffset == -2
+ || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return new EmptyOffHeapVectorValues(fieldEntry.dimension);
}
IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
- int byteSize;
- switch (fieldEntry.vectorEncoding) {
- case BYTE:
- byteSize = fieldEntry.dimension;
- break;
- case FLOAT32:
- byteSize = fieldEntry.dimension * Float.BYTES;
- break;
- default:
- throw new AssertionError();
- }
+ int byteSize = fieldEntry.dimension * Float.BYTES;
if (fieldEntry.docsWithFieldOffset == -1) {
return new DenseOffHeapVectorValues(
fieldEntry.dimension, fieldEntry.size, bytesSlice, byteSize);
@@ -149,7 +131,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public RandomAccessVectorValues copy() throws IOException {
+ public RandomAccessVectorValues<float[]> copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
}
@@ -225,7 +207,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public RandomAccessVectorValues copy() throws IOException {
+ public RandomAccessVectorValues<float[]> copy() throws IOException {
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
}
@@ -297,7 +279,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
}
@Override
- public RandomAccessVectorValues copy() throws IOException {
+ public RandomAccessVectorValues<float[]> copy() throws IOException {
throw new UnsupportedOperationException();
}
@@ -306,11 +288,6 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
throw new UnsupportedOperationException();
}
- @Override
- public BytesRef binaryValue(int targetOrd) throws IOException {
- throw new UnsupportedOperationException();
- }
-
@Override
public int ordToDoc(int ord) {
throw new UnsupportedOperationException();
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java
index 428311e92e4..9df9dbfba27 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java
@@ -27,6 +27,7 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
+import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentReadState;
@@ -255,6 +256,16 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
}
}
+ @Override
+ public ByteVectorValues getByteVectorValues(String field) throws IOException {
+ KnnVectorsReader knnVectorsReader = fields.get(field);
+ if (knnVectorsReader == null) {
+ return null;
+ } else {
+ return knnVectorsReader.getByteVectorValues(field);
+ }
+ }
+
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
diff --git a/lucene/core/src/java/org/apache/lucene/document/KnnVectorField.java b/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java
similarity index 53%
copy from lucene/core/src/java/org/apache/lucene/document/KnnVectorField.java
copy to lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java
index 2376f1806d3..0226873f932 100644
--- a/lucene/core/src/java/org/apache/lucene/document/KnnVectorField.java
+++ b/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java
@@ -17,103 +17,75 @@
package org.apache.lucene.document;
+import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
-import org.apache.lucene.index.VectorValues;
+import org.apache.lucene.search.KnnByteVectorQuery;
+import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
-import org.apache.lucene.util.VectorUtil;
/**
- * A field that contains a single floating-point numeric vector (or none) for each document. Vectors
- * are dense - that is, every dimension of a vector contains an explicit value, stored packed into
- * an array (of type float[]) whose length is the vector dimension. Values can be retrieved using
- * {@link VectorValues}, which is a forward-only docID-based iterator and also offers random-access
- * by dense ordinal (not docId). {@link VectorSimilarityFunction} may be used to compare vectors at
- * query time (for example as part of result ranking). A KnnVectorField may be associated with a
+ * A field that contains a single byte numeric vector (or none) for each document. Vectors are dense
+ * - that is, every dimension of a vector contains an explicit value, stored packed into an array
+ * (of type byte[]) whose length is the vector dimension. Values can be retrieved using {@link
+ * ByteVectorValues}, which is a forward-only docID-based iterator and also offers random-access by
+ * dense ordinal (not docId). {@link VectorSimilarityFunction} may be used to compare vectors at
+ * query time (for example as part of result ranking). A KnnByteVectorField may be associated with a
* search similarity function defining the metric used for nearest-neighbor search among vectors of
* that field.
*
* @lucene.experimental
*/
-public class KnnVectorField extends Field {
-
- private static FieldType createType(float[] v, VectorSimilarityFunction similarityFunction) {
- if (v == null) {
- throw new IllegalArgumentException("vector value must not be null");
- }
- return createType(v.length, VectorEncoding.FLOAT32, similarityFunction);
- }
+public class KnnByteVectorField extends Field {
private static FieldType createType(BytesRef v, VectorSimilarityFunction similarityFunction) {
if (v == null) {
throw new IllegalArgumentException("vector value must not be null");
}
- return createType(v.length, VectorEncoding.BYTE, similarityFunction);
- }
-
- private static FieldType createType(
- int dimension, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) {
+ int dimension = v.length;
if (dimension == 0) {
throw new IllegalArgumentException("cannot index an empty vector");
}
- if (dimension > VectorValues.MAX_DIMENSIONS) {
+ if (dimension > ByteVectorValues.MAX_DIMENSIONS) {
throw new IllegalArgumentException(
- "cannot index vectors with dimension greater than " + VectorValues.MAX_DIMENSIONS);
+ "cannot index vectors with dimension greater than " + ByteVectorValues.MAX_DIMENSIONS);
}
if (similarityFunction == null) {
throw new IllegalArgumentException("similarity function must not be null");
}
FieldType type = new FieldType();
- type.setVectorAttributes(dimension, vectorEncoding, similarityFunction);
+ type.setVectorAttributes(dimension, VectorEncoding.BYTE, similarityFunction);
type.freeze();
return type;
}
/**
- * A convenience method for creating a vector field type with the default FLOAT32 encoding.
+ * Create a new vector query for the provided field targeting the byte vector
*
- * @param dimension dimension of vectors
- * @param similarityFunction a function defining vector proximity.
- * @throws IllegalArgumentException if any parameter is null, or has dimension > 1024.
+ * @param field The field to query
+ * @param queryVector The byte vector target
+ * @param k The number of nearest neighbors to gather
+ * @return A new vector query
*/
- public static FieldType createFieldType(
- int dimension, VectorSimilarityFunction similarityFunction) {
- return createFieldType(dimension, VectorEncoding.FLOAT32, similarityFunction);
+ public static Query newVectorQuery(String field, BytesRef queryVector, int k) {
+ return new KnnByteVectorQuery(field, queryVector, k);
}
/**
* A convenience method for creating a vector field type.
*
* @param dimension dimension of vectors
- * @param vectorEncoding the encoding of the scalar values
* @param similarityFunction a function defining vector proximity.
* @throws IllegalArgumentException if any parameter is null, or has dimension > 1024.
*/
public static FieldType createFieldType(
- int dimension, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) {
+ int dimension, VectorSimilarityFunction similarityFunction) {
FieldType type = new FieldType();
- type.setVectorAttributes(dimension, vectorEncoding, similarityFunction);
+ type.setVectorAttributes(dimension, VectorEncoding.BYTE, similarityFunction);
type.freeze();
return type;
}
- /**
- * Creates a numeric vector field. Fields are single-valued: each document has either one value or
- * no value. Vectors of a single field share the same dimension and similarity function. Note that
- * some vector similarities (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require values to
- * be unit-length, which can be enforced using {@link VectorUtil#l2normalize(float[])}.
- *
- * @param name field name
- * @param vector value
- * @param similarityFunction a function defining vector proximity.
- * @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
- * dimension > 1024.
- */
- public KnnVectorField(String name, float[] vector, VectorSimilarityFunction similarityFunction) {
- super(name, createType(vector, similarityFunction));
- fieldsData = vector;
- }
-
/**
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
* no value. Vectors of a single field share the same dimension and similarity function. Note that
@@ -126,7 +98,8 @@ public class KnnVectorField extends Field {
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
* dimension > 1024.
*/
- public KnnVectorField(String name, BytesRef vector, VectorSimilarityFunction similarityFunction) {
+ public KnnByteVectorField(
+ String name, BytesRef vector, VectorSimilarityFunction similarityFunction) {
super(name, createType(vector, similarityFunction));
fieldsData = vector;
}
@@ -141,7 +114,7 @@ public class KnnVectorField extends Field {
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
* dimension > 1024.
*/
- public KnnVectorField(String name, float[] vector) {
+ public KnnByteVectorField(String name, BytesRef vector) {
this(name, vector, VectorSimilarityFunction.EUCLIDEAN);
}
@@ -155,43 +128,21 @@ public class KnnVectorField extends Field {
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
* dimension > 1024.
*/
- public KnnVectorField(String name, float[] vector, FieldType fieldType) {
- super(name, fieldType);
- if (fieldType.vectorEncoding() != VectorEncoding.FLOAT32) {
- throw new IllegalArgumentException(
- "Attempt to create a vector for field "
- + name
- + " using float[] but the field encoding is "
- + fieldType.vectorEncoding());
- }
- fieldsData = vector;
- }
-
- /**
- * Creates a numeric vector field. Fields are single-valued: each document has either one value or
- * no value. Vectors of a single field share the same dimension and similarity function.
- *
- * @param name field name
- * @param vector value
- * @param fieldType field type
- * @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
- * dimension > 1024.
- */
- public KnnVectorField(String name, BytesRef vector, FieldType fieldType) {
+ public KnnByteVectorField(String name, BytesRef vector, FieldType fieldType) {
super(name, fieldType);
if (fieldType.vectorEncoding() != VectorEncoding.BYTE) {
throw new IllegalArgumentException(
"Attempt to create a vector for field "
+ name
- + " using BytesRef but the field encoding is "
+ + " using byte[] but the field encoding is "
+ fieldType.vectorEncoding());
}
fieldsData = vector;
}
/** Return the vector value of this field */
- public float[] vectorValue() {
- return (float[]) fieldsData;
+ public BytesRef vectorValue() {
+ return (BytesRef) fieldsData;
}
/**
@@ -199,7 +150,7 @@ public class KnnVectorField extends Field {
*
* @param value the value to set; must not be null, and length must match the field type
*/
- public void setVectorValue(float[] value) {
+ public void setVectorValue(BytesRef value) {
if (value == null) {
throw new IllegalArgumentException("value must not be null");
}
diff --git a/lucene/core/src/java/org/apache/lucene/document/KnnVectorField.java b/lucene/core/src/java/org/apache/lucene/document/KnnVectorField.java
index 2376f1806d3..2518e01dbc6 100644
--- a/lucene/core/src/java/org/apache/lucene/document/KnnVectorField.java
+++ b/lucene/core/src/java/org/apache/lucene/document/KnnVectorField.java
@@ -20,7 +20,8 @@ package org.apache.lucene.document;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
-import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.search.KnnVectorQuery;
+import org.apache.lucene.search.Query;
import org.apache.lucene.util.VectorUtil;
/**
@@ -41,18 +42,7 @@ public class KnnVectorField extends Field {
if (v == null) {
throw new IllegalArgumentException("vector value must not be null");
}
- return createType(v.length, VectorEncoding.FLOAT32, similarityFunction);
- }
-
- private static FieldType createType(BytesRef v, VectorSimilarityFunction similarityFunction) {
- if (v == null) {
- throw new IllegalArgumentException("vector value must not be null");
- }
- return createType(v.length, VectorEncoding.BYTE, similarityFunction);
- }
-
- private static FieldType createType(
- int dimension, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) {
+ int dimension = v.length;
if (dimension == 0) {
throw new IllegalArgumentException("cannot index an empty vector");
}
@@ -64,61 +54,43 @@ public class KnnVectorField extends Field {
throw new IllegalArgumentException("similarity function must not be null");
}
FieldType type = new FieldType();
- type.setVectorAttributes(dimension, vectorEncoding, similarityFunction);
+ type.setVectorAttributes(dimension, VectorEncoding.FLOAT32, similarityFunction);
type.freeze();
return type;
}
- /**
- * A convenience method for creating a vector field type with the default FLOAT32 encoding.
- *
- * @param dimension dimension of vectors
- * @param similarityFunction a function defining vector proximity.
- * @throws IllegalArgumentException if any parameter is null, or has dimension > 1024.
- */
- public static FieldType createFieldType(
- int dimension, VectorSimilarityFunction similarityFunction) {
- return createFieldType(dimension, VectorEncoding.FLOAT32, similarityFunction);
- }
-
/**
* A convenience method for creating a vector field type.
*
* @param dimension dimension of vectors
- * @param vectorEncoding the encoding of the scalar values
* @param similarityFunction a function defining vector proximity.
* @throws IllegalArgumentException if any parameter is null, or has dimension > 1024.
*/
public static FieldType createFieldType(
- int dimension, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) {
+ int dimension, VectorSimilarityFunction similarityFunction) {
FieldType type = new FieldType();
- type.setVectorAttributes(dimension, vectorEncoding, similarityFunction);
+ type.setVectorAttributes(dimension, VectorEncoding.FLOAT32, similarityFunction);
type.freeze();
return type;
}
/**
- * Creates a numeric vector field. Fields are single-valued: each document has either one value or
- * no value. Vectors of a single field share the same dimension and similarity function. Note that
- * some vector similarities (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require values to
- * be unit-length, which can be enforced using {@link VectorUtil#l2normalize(float[])}.
+ * Create a new vector query for the provided field targeting the float vector
*
- * @param name field name
- * @param vector value
- * @param similarityFunction a function defining vector proximity.
- * @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
- * dimension > 1024.
+ * @param field The field to query
+ * @param queryVector The float vector target
+ * @param k The number of nearest neighbors to gather
+ * @return A new vector query
*/
- public KnnVectorField(String name, float[] vector, VectorSimilarityFunction similarityFunction) {
- super(name, createType(vector, similarityFunction));
- fieldsData = vector;
+ public static Query newVectorQuery(String field, float[] queryVector, int k) {
+ return new KnnVectorQuery(field, queryVector, k);
}
/**
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
* no value. Vectors of a single field share the same dimension and similarity function. Note that
* some vector similarities (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require values to
- * be constant-length.
+ * be unit-length, which can be enforced using {@link VectorUtil#l2normalize(float[])}.
*
* @param name field name
* @param vector value
@@ -126,7 +98,7 @@ public class KnnVectorField extends Field {
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
* dimension > 1024.
*/
- public KnnVectorField(String name, BytesRef vector, VectorSimilarityFunction similarityFunction) {
+ public KnnVectorField(String name, float[] vector, VectorSimilarityFunction similarityFunction) {
super(name, createType(vector, similarityFunction));
fieldsData = vector;
}
@@ -167,28 +139,6 @@ public class KnnVectorField extends Field {
fieldsData = vector;
}
- /**
- * Creates a numeric vector field. Fields are single-valued: each document has either one value or
- * no value. Vectors of a single field share the same dimension and similarity function.
- *
- * @param name field name
- * @param vector value
- * @param fieldType field type
- * @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
- * dimension > 1024.
- */
- public KnnVectorField(String name, BytesRef vector, FieldType fieldType) {
- super(name, fieldType);
- if (fieldType.vectorEncoding() != VectorEncoding.BYTE) {
- throw new IllegalArgumentException(
- "Attempt to create a vector for field "
- + name
- + " using BytesRef but the field encoding is "
- + fieldType.vectorEncoding());
- }
- fieldsData = vector;
- }
-
/** Return the vector value of this field */
public float[] vectorValue() {
return (float[]) fieldsData;
diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorValues.java b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java
similarity index 87%
copy from lucene/core/src/java/org/apache/lucene/index/VectorValues.java
copy to lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java
index 79b1033b4aa..eaac008388c 100644
--- a/lucene/core/src/java/org/apache/lucene/index/VectorValues.java
+++ b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java
@@ -17,23 +17,23 @@
package org.apache.lucene.index;
import java.io.IOException;
-import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BytesRef;
/**
* This class provides access to per-document floating point vector values indexed as {@link
- * KnnVectorField}.
+ * KnnByteVectorField}.
*
* @lucene.experimental
*/
-public abstract class VectorValues extends DocIdSetIterator {
+public abstract class ByteVectorValues extends DocIdSetIterator {
/** The maximum length of a vector */
public static final int MAX_DIMENSIONS = 1024;
/** Sole constructor */
- protected VectorValues() {}
+ protected ByteVectorValues() {}
/** Return the dimension of the vectors */
public abstract int dimension();
@@ -57,7 +57,7 @@ public abstract class VectorValues extends DocIdSetIterator {
*
* @return the vector value
*/
- public abstract float[] vectorValue() throws IOException;
+ public abstract BytesRef vectorValue() throws IOException;
/**
* Return the binary encoded vector value for the current document ID. These are the bytes
@@ -67,7 +67,7 @@ public abstract class VectorValues extends DocIdSetIterator {
*
* @return the binary value
*/
- public BytesRef binaryValue() throws IOException {
- throw new UnsupportedOperationException();
+ public final BytesRef binaryValue() throws IOException {
+ return vectorValue();
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java
index b9495832ba2..730307dd380 100644
--- a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java
+++ b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java
@@ -34,6 +34,7 @@ import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
+import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
@@ -2588,71 +2589,37 @@ public final class CheckIndex implements Closeable {
+ "\" has vector values but dimension is "
+ dimension);
}
- VectorValues values = reader.getVectorValues(fieldInfo.name);
- if (values == null) {
+ if (reader.getVectorValues(fieldInfo.name) == null
+ && reader.getByteVectorValues(fieldInfo.name) == null) {
continue;
}
status.totalKnnVectorFields++;
-
- int docCount = 0;
- int everyNdoc = Math.max(values.size() / 64, 1);
- while (values.nextDoc() != NO_MORE_DOCS) {
- // search the first maxNumSearches vectors to exercise the graph
- if (values.docID() % everyNdoc == 0) {
- final TopDocs docs;
- switch (fieldInfo.getVectorEncoding()) {
- case FLOAT32:
- docs =
- reader
- .getVectorReader()
- .search(
- fieldInfo.name, values.vectorValue(), 10, null, Integer.MAX_VALUE);
- break;
- case BYTE:
- docs =
- reader
- .getVectorReader()
- .search(
- fieldInfo.name, values.binaryValue(), 10, null, Integer.MAX_VALUE);
- break;
- default:
- throw new IllegalArgumentException(
- "unknown vector encoding: " + fieldInfo.getVectorEncoding());
- }
- if (docs.scoreDocs.length == 0) {
- throw new CheckIndexException(
- "Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
- }
- }
- float[] vectorValue = values.vectorValue();
- int valueLength = vectorValue.length;
- if (valueLength != dimension) {
+ switch (fieldInfo.getVectorEncoding()) {
+ case BYTE:
+ checkByteVectorValues(
+ Objects.requireNonNull(reader.getByteVectorValues(fieldInfo.name)),
+ fieldInfo,
+ status,
+ reader);
+ break;
+ case FLOAT32:
+ checkFloatVectorValues(
+ Objects.requireNonNull(reader.getVectorValues(fieldInfo.name)),
+ fieldInfo,
+ status,
+ reader);
+ break;
+ default:
throw new CheckIndexException(
"Field \""
+ fieldInfo.name
- + "\" has a value whose dimension="
- + valueLength
- + " not matching the field's dimension="
- + dimension);
- }
- ++docCount;
+ + "\" has unexpected vector encoding: "
+ + fieldInfo.getVectorEncoding());
}
- if (docCount != values.size()) {
- throw new CheckIndexException(
- "Field \""
- + fieldInfo.name
- + "\" has size="
- + values.size()
- + " but when iterated, returns "
- + docCount
- + " docs with values");
- }
- status.totalVectorValues += docCount;
}
}
}
-
msg(
infoStream,
String.format(
@@ -2676,6 +2643,96 @@ public final class CheckIndex implements Closeable {
return status;
}
+ private static void checkFloatVectorValues(
+ VectorValues values,
+ FieldInfo fieldInfo,
+ CheckIndex.Status.VectorValuesStatus status,
+ CodecReader codecReader)
+ throws IOException {
+ int docCount = 0;
+ int everyNdoc = Math.max(values.size() / 64, 1);
+ while (values.nextDoc() != NO_MORE_DOCS) {
+ // search the first maxNumSearches vectors to exercise the graph
+ if (values.docID() % everyNdoc == 0) {
+ TopDocs docs =
+ codecReader
+ .getVectorReader()
+ .search(fieldInfo.name, values.vectorValue(), 10, null, Integer.MAX_VALUE);
+ if (docs.scoreDocs.length == 0) {
+ throw new CheckIndexException(
+ "Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
+ }
+ }
+ int valueLength = values.vectorValue().length;
+ if (valueLength != fieldInfo.getVectorDimension()) {
+ throw new CheckIndexException(
+ "Field \""
+ + fieldInfo.name
+ + "\" has a value whose dimension="
+ + valueLength
+ + " not matching the field's dimension="
+ + fieldInfo.getVectorDimension());
+ }
+ ++docCount;
+ }
+ if (docCount != values.size()) {
+ throw new CheckIndexException(
+ "Field \""
+ + fieldInfo.name
+ + "\" has size="
+ + values.size()
+ + " but when iterated, returns "
+ + docCount
+ + " docs with values");
+ }
+ status.totalVectorValues += docCount;
+ }
+
+ private static void checkByteVectorValues(
+ ByteVectorValues values,
+ FieldInfo fieldInfo,
+ CheckIndex.Status.VectorValuesStatus status,
+ CodecReader codecReader)
+ throws IOException {
+ int docCount = 0;
+ int everyNdoc = Math.max(values.size() / 64, 1);
+ while (values.nextDoc() != NO_MORE_DOCS) {
+ // search the first maxNumSearches vectors to exercise the graph
+ if (values.docID() % everyNdoc == 0) {
+ TopDocs docs =
+ codecReader
+ .getVectorReader()
+ .search(fieldInfo.name, values.vectorValue(), 10, null, Integer.MAX_VALUE);
+ if (docs.scoreDocs.length == 0) {
+ throw new CheckIndexException(
+ "Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
+ }
+ }
+ int valueLength = values.vectorValue().length;
+ if (valueLength != fieldInfo.getVectorDimension()) {
+ throw new CheckIndexException(
+ "Field \""
+ + fieldInfo.name
+ + "\" has a value whose dimension="
+ + valueLength
+ + " not matching the field's dimension="
+ + fieldInfo.getVectorDimension());
+ }
+ ++docCount;
+ }
+ if (docCount != values.size()) {
+ throw new CheckIndexException(
+ "Field \""
+ + fieldInfo.name
+ + "\" has size="
+ + values.size()
+ + " but when iterated, returns "
+ + docCount
+ + " docs with values");
+ }
+ status.totalVectorValues += docCount;
+ }
+
/**
* Walks the entire N-dimensional points space, verifying that all points fall within the last
* cell's boundaries.
diff --git a/lucene/core/src/java/org/apache/lucene/index/CodecReader.java b/lucene/core/src/java/org/apache/lucene/index/CodecReader.java
index 70a4afd014b..4ddb1c5d17c 100644
--- a/lucene/core/src/java/org/apache/lucene/index/CodecReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/CodecReader.java
@@ -233,7 +233,9 @@ public abstract class CodecReader extends LeafReader {
public final VectorValues getVectorValues(String field) throws IOException {
ensureOpen();
FieldInfo fi = getFieldInfos().fieldInfo(field);
- if (fi == null || fi.getVectorDimension() == 0) {
+ if (fi == null
+ || fi.getVectorDimension() == 0
+ || fi.getVectorEncoding() != VectorEncoding.FLOAT32) {
// Field does not exist or does not index vectors
return null;
}
@@ -241,6 +243,20 @@ public abstract class CodecReader extends LeafReader {
return getVectorReader().getVectorValues(field);
}
+ @Override
+ public final ByteVectorValues getByteVectorValues(String field) throws IOException {
+ ensureOpen();
+ FieldInfo fi = getFieldInfos().fieldInfo(field);
+ if (fi == null
+ || fi.getVectorDimension() == 0
+ || fi.getVectorEncoding() != VectorEncoding.BYTE) {
+ // Field does not exist or does not index vectors
+ return null;
+ }
+
+ return getVectorReader().getByteVectorValues(field);
+ }
+
@Override
public final TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
diff --git a/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java
index 99ce3bcd980..40ddf48cbbb 100644
--- a/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java
@@ -53,6 +53,11 @@ abstract class DocValuesLeafReader extends LeafReader {
throw new UnsupportedOperationException();
}
+ @Override
+ public final ByteVectorValues getByteVectorValues(String field) throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
@Override
public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
diff --git a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java
index cfde6605b08..7e36380a501 100644
--- a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java
@@ -323,6 +323,15 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
return new ExitableVectorValues(vectorValues);
}
+ @Override
+ public ByteVectorValues getByteVectorValues(String field) throws IOException {
+ final ByteVectorValues vectorValues = in.getByteVectorValues(field);
+ if (vectorValues == null) {
+ return null;
+ }
+ return new ExitableByteVectorValues(vectorValues);
+ }
+
@Override
public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
@@ -387,17 +396,18 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
}
}
- private class ExitableVectorValues extends FilterVectorValues {
+ private class ExitableVectorValues extends VectorValues {
private int docToCheck;
+ private final VectorValues vectorValues;
public ExitableVectorValues(VectorValues vectorValues) {
- super(vectorValues);
+ this.vectorValues = vectorValues;
docToCheck = 0;
}
@Override
public int advance(int target) throws IOException {
- final int advance = super.advance(target);
+ final int advance = vectorValues.advance(target);
if (advance >= docToCheck) {
checkAndThrow();
docToCheck = advance + DOCS_BETWEEN_TIMEOUT_CHECK;
@@ -405,9 +415,14 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
return advance;
}
+ @Override
+ public int docID() {
+ return vectorValues.docID();
+ }
+
@Override
public int nextDoc() throws IOException {
- final int nextDoc = super.nextDoc();
+ final int nextDoc = vectorValues.nextDoc();
if (nextDoc >= docToCheck) {
checkAndThrow();
docToCheck = nextDoc + DOCS_BETWEEN_TIMEOUT_CHECK;
@@ -415,14 +430,91 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
return nextDoc;
}
+ @Override
+ public int dimension() {
+ return vectorValues.dimension();
+ }
+
@Override
public float[] vectorValue() throws IOException {
- return in.vectorValue();
+ return vectorValues.vectorValue();
+ }
+
+ @Override
+ public int size() {
+ return vectorValues.size();
}
@Override
public BytesRef binaryValue() throws IOException {
- return in.binaryValue();
+ return vectorValues.binaryValue();
+ }
+
+ /**
+ * Throws {@link ExitingReaderException} if {@link QueryTimeout#shouldExit()} returns true, or
+ * if {@link Thread#interrupted()} returns true.
+ */
+ private void checkAndThrow() {
+ if (queryTimeout.shouldExit()) {
+ throw new ExitingReaderException(
+ "The request took too long to iterate over vector values. Timeout: "
+ + queryTimeout.toString()
+ + ", VectorValues="
+ + in);
+ } else if (Thread.interrupted()) {
+ throw new ExitingReaderException(
+ "Interrupted while iterating over vector values. VectorValues=" + in);
+ }
+ }
+ }
+
+ private class ExitableByteVectorValues extends ByteVectorValues {
+ private int docToCheck;
+ private final ByteVectorValues vectorValues;
+
+ public ExitableByteVectorValues(ByteVectorValues vectorValues) {
+ this.vectorValues = vectorValues;
+ docToCheck = 0;
+ }
+
+ @Override
+ public int advance(int target) throws IOException {
+ final int advance = vectorValues.advance(target);
+ if (advance >= docToCheck) {
+ checkAndThrow();
+ docToCheck = advance + DOCS_BETWEEN_TIMEOUT_CHECK;
+ }
+ return advance;
+ }
+
+ @Override
+ public int docID() {
+ return vectorValues.docID();
+ }
+
+ @Override
+ public int nextDoc() throws IOException {
+ final int nextDoc = vectorValues.nextDoc();
+ if (nextDoc >= docToCheck) {
+ checkAndThrow();
+ docToCheck = nextDoc + DOCS_BETWEEN_TIMEOUT_CHECK;
+ }
+ return nextDoc;
+ }
+
+ @Override
+ public int dimension() {
+ return vectorValues.dimension();
+ }
+
+ @Override
+ public int size() {
+ return vectorValues.size();
+ }
+
+ @Override
+ public BytesRef vectorValue() throws IOException {
+ return vectorValues.vectorValue();
}
/**
diff --git a/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java
index cedec13a1d0..e791772de26 100644
--- a/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java
@@ -351,6 +351,11 @@ public abstract class FilterLeafReader extends LeafReader {
return in.getVectorValues(field);
}
+ @Override
+ public ByteVectorValues getByteVectorValues(String field) throws IOException {
+ return in.getByteVectorValues(field);
+ }
+
@Override
public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
diff --git a/lucene/core/src/java/org/apache/lucene/index/IndexingChain.java b/lucene/core/src/java/org/apache/lucene/index/IndexingChain.java
index 14d55b36252..c2a46c195f3 100644
--- a/lucene/core/src/java/org/apache/lucene/index/IndexingChain.java
+++ b/lucene/core/src/java/org/apache/lucene/index/IndexingChain.java
@@ -38,6 +38,7 @@ import org.apache.lucene.codecs.NormsProducer;
import org.apache.lucene.codecs.PointsFormat;
import org.apache.lucene.codecs.PointsWriter;
import org.apache.lucene.document.FieldType;
+import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Sort;
@@ -721,15 +722,7 @@ final class IndexingChain implements Accountable {
pf.pointValuesWriter.addPackedValue(docID, field.binaryValue());
}
if (fieldType.vectorDimension() != 0) {
- switch (fieldType.vectorEncoding()) {
- case BYTE:
- pf.knnFieldVectorsWriter.addValue(docID, field.binaryValue());
- break;
- default:
- case FLOAT32:
- pf.knnFieldVectorsWriter.addValue(docID, ((KnnVectorField) field).vectorValue());
- break;
- }
+ indexVectorValue(docID, pf, fieldType.vectorEncoding(), field);
}
return indexedField;
}
@@ -963,6 +956,24 @@ final class IndexingChain implements Accountable {
}
}
+ @SuppressWarnings("unchecked")
+ private void indexVectorValue(
+ int docID, PerField pf, VectorEncoding vectorEncoding, IndexableField field)
+ throws IOException {
+ switch (vectorEncoding) {
+ case BYTE:
+ ((KnnFieldVectorsWriter<BytesRef>) pf.knnFieldVectorsWriter)
+ .addValue(docID, ((KnnByteVectorField) field).vectorValue());
+ break;
+ case FLOAT32:
+ ((KnnFieldVectorsWriter<float[]>) pf.knnFieldVectorsWriter)
+ .addValue(docID, ((KnnVectorField) field).vectorValue());
+ break;
+ default:
+ throw new IllegalArgumentException("unknown vector encoding=" + vectorEncoding);
+ }
+ }
+
/** Returns a previously created {@link PerField}, or null if this field name wasn't seen yet. */
private PerField getPerField(String name) {
final int hashPos = name.hashCode() & hashMask;
diff --git a/lucene/core/src/java/org/apache/lucene/index/LeafReader.java b/lucene/core/src/java/org/apache/lucene/index/LeafReader.java
index e71ff1f3edf..7e500480ac3 100644
--- a/lucene/core/src/java/org/apache/lucene/index/LeafReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/LeafReader.java
@@ -208,6 +208,14 @@ public abstract class LeafReader extends IndexReader {
*/
public abstract VectorValues getVectorValues(String field) throws IOException;
+ /**
+ * Returns {@link ByteVectorValues} for this field, or null if no {@link ByteVectorValues} were
+ * indexed. The returned instance should only be used by a single thread.
+ *
+ * @lucene.experimental
+ */
+ public abstract ByteVectorValues getByteVectorValues(String field) throws IOException;
+
/**
* Return the k nearest neighbor documents as determined by comparison of their vector values for
* this field, to the given vector, by the field's similarity function. The score of each document
diff --git a/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java
index 8f13f91ea28..0d63636ddb7 100644
--- a/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java
@@ -442,6 +442,13 @@ public class ParallelLeafReader extends LeafReader {
return reader == null ? null : reader.getVectorValues(fieldName);
}
+ @Override
+ public ByteVectorValues getByteVectorValues(String fieldName) throws IOException {
+ ensureOpen();
+ LeafReader reader = fieldToReader.get(fieldName);
+ return reader == null ? null : reader.getByteVectorValues(fieldName);
+ }
+
@Override
public TopDocs searchNearestVectors(
String fieldName, float[] target, int k, Bits acceptDocs, int visitedLimit)
diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java
index 912bb54d5cf..4d664f4c759 100644
--- a/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java
+++ b/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java
@@ -168,6 +168,11 @@ public final class SlowCodecReaderWrapper {
return reader.getVectorValues(field);
}
+ @Override
+ public ByteVectorValues getByteVectorValues(String field) throws IOException {
+ return reader.getByteVectorValues(field);
+ }
+
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java
index 1444c97c7d1..3fae5973855 100644
--- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java
@@ -222,34 +222,21 @@ public final class SortingCodecReader extends FilterCodecReader {
final FixedBitSet docsWithField;
final float[][] vectors;
final ByteBuffer vectorAsBytes;
- final BytesRef[] binaryVectors;
private int docId = -1;
- SortingVectorValues(VectorValues delegate, Sorter.DocMap sortMap, VectorEncoding encoding)
- throws IOException {
+ SortingVectorValues(VectorValues delegate, Sorter.DocMap sortMap) throws IOException {
this.size = delegate.size();
this.dimension = delegate.dimension();
docsWithField = new FixedBitSet(sortMap.size());
- if (encoding == VectorEncoding.BYTE) {
- vectors = null;
- binaryVectors = new BytesRef[sortMap.size()];
- vectorAsBytes = null;
- } else {
- vectors = new float[sortMap.size()][];
- binaryVectors = null;
- vectorAsBytes =
- ByteBuffer.allocate(delegate.dimension() * encoding.byteSize)
- .order(ByteOrder.LITTLE_ENDIAN);
- }
+ vectors = new float[sortMap.size()][];
+ vectorAsBytes =
+ ByteBuffer.allocate(delegate.dimension() * VectorEncoding.FLOAT32.byteSize)
+ .order(ByteOrder.LITTLE_ENDIAN);
for (int doc = delegate.nextDoc(); doc != NO_MORE_DOCS; doc = delegate.nextDoc()) {
int newDocID = sortMap.oldToNew(doc);
docsWithField.set(newDocID);
- if (encoding == VectorEncoding.BYTE) {
- binaryVectors[newDocID] = BytesRef.deepCopyOf(delegate.binaryValue());
- } else {
- vectors[newDocID] = delegate.vectorValue().clone();
- }
+ vectors[newDocID] = delegate.vectorValue().clone();
}
}
@@ -265,12 +252,8 @@ public final class SortingCodecReader extends FilterCodecReader {
@Override
public BytesRef binaryValue() throws IOException {
- if (binaryVectors != null) {
- return binaryVectors[docId];
- } else {
- vectorAsBytes.asFloatBuffer().put(vectors[docId]);
- return new BytesRef(vectorAsBytes.array());
- }
+ vectorAsBytes.asFloatBuffer().put(vectors[docId]);
+ return new BytesRef(vectorAsBytes.array());
}
@Override
@@ -297,6 +280,60 @@ public final class SortingCodecReader extends FilterCodecReader {
}
}
+ private static class SortingByteVectorValues extends ByteVectorValues {
+ final int size;
+ final int dimension;
+ final FixedBitSet docsWithField;
+ final BytesRef[] binaryVectors;
+
+ private int docId = -1;
+
+ SortingByteVectorValues(ByteVectorValues delegate, Sorter.DocMap sortMap) throws IOException {
+ this.size = delegate.size();
+ this.dimension = delegate.dimension();
+ docsWithField = new FixedBitSet(sortMap.size());
+ binaryVectors = new BytesRef[sortMap.size()];
+ for (int doc = delegate.nextDoc(); doc != NO_MORE_DOCS; doc = delegate.nextDoc()) {
+ int newDocID = sortMap.oldToNew(doc);
+ docsWithField.set(newDocID);
+ binaryVectors[newDocID] = BytesRef.deepCopyOf(delegate.vectorValue());
+ }
+ }
+
+ @Override
+ public int docID() {
+ return docId;
+ }
+
+ @Override
+ public int nextDoc() throws IOException {
+ return advance(docId + 1);
+ }
+
+ @Override
+ public BytesRef vectorValue() throws IOException {
+ return binaryVectors[docId];
+ }
+
+ @Override
+ public int dimension() {
+ return dimension;
+ }
+
+ @Override
+ public int size() {
+ return size;
+ }
+
+ @Override
+ public int advance(int target) throws IOException {
+ if (target >= docsWithField.length()) {
+ return NO_MORE_DOCS;
+ }
+ return docId = docsWithField.nextSetBit(target);
+ }
+ }
+
/**
* Return a sorted view of <code>reader</code> according to the order defined by <code>sort</code>
* . If the reader is already sorted, this method might return the reader as-is.
@@ -465,9 +502,12 @@ public final class SortingCodecReader extends FilterCodecReader {
@Override
public VectorValues getVectorValues(String field) throws IOException {
- FieldInfo fi = in.getFieldInfos().fieldInfo(field);
- return new SortingVectorValues(
- delegate.getVectorValues(field), docMap, fi.getVectorEncoding());
+ return new SortingVectorValues(delegate.getVectorValues(field), docMap);
+ }
+
+ @Override
+ public ByteVectorValues getByteVectorValues(String field) throws IOException {
+ return new SortingByteVectorValues(delegate.getByteVectorValues(field), docMap);
}
@Override
diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorValues.java b/lucene/core/src/java/org/apache/lucene/index/VectorValues.java
index 79b1033b4aa..549fa6ef55b 100644
--- a/lucene/core/src/java/org/apache/lucene/index/VectorValues.java
+++ b/lucene/core/src/java/org/apache/lucene/index/VectorValues.java
@@ -61,8 +61,8 @@ public abstract class VectorValues extends DocIdSetIterator {
/**
* Return the binary encoded vector value for the current document ID. These are the bytes
- * corresponding to the float array return by {@link #vectorValue}. It is illegal to call this
- * method when the iterator is not positioned: before advancing, or after failing to advance. The
+ * corresponding to the array returned by {@link #vectorValue}. It is illegal to call this method
+ * when the iterator is not positioned: before advancing, or after failing to advance. The
* returned storage may be shared across calls, re-used and modified as the iterator advances.
*
* @return the binary value
diff --git a/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java b/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java
index eb3320e0a44..afaf59751b7 100644
--- a/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java
@@ -31,7 +31,8 @@ import org.apache.lucene.index.Terms;
/**
* A {@link Query} that matches documents that contain either a {@link
- * org.apache.lucene.document.KnnVectorField}, or a field that indexes norms or doc values.
+ * org.apache.lucene.document.KnnVectorField}, {@link org.apache.lucene.document.KnnByteVectorField}
+ * or a field that indexes norms or doc values.
*/
public class FieldExistsQuery extends Query {
private String field;
@@ -126,7 +127,19 @@ public class FieldExistsQuery extends Query {
break;
}
} else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors
- if (leaf.getVectorValues(field).size() != leaf.maxDoc()) {
+ final int numVectors;
+ switch (fieldInfo.getVectorEncoding()) {
+ case FLOAT32:
+ numVectors = leaf.getVectorValues(field).size();
+ break;
+ case BYTE:
+ numVectors = leaf.getByteVectorValues(field).size();
+ break;
+ default:
+ throw new IllegalArgumentException(
+ "unknown vector encoding=" + fieldInfo.getVectorEncoding());
+ }
+ if (numVectors != leaf.maxDoc()) {
allReadersRewritable = false;
break;
}
@@ -178,7 +191,18 @@ public class FieldExistsQuery extends Query {
if (fieldInfo.hasNorms()) { // the field indexes norms
iterator = context.reader().getNormValues(field);
} else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors
- iterator = context.reader().getVectorValues(field);
+ switch (fieldInfo.getVectorEncoding()) {
+ case FLOAT32:
+ iterator = context.reader().getVectorValues(field);
+ break;
+ case BYTE:
+ iterator = context.reader().getByteVectorValues(field);
+ break;
+ default:
+ // reject encodings this query does not know how to iterate
+ throw new IllegalArgumentException(
+ "unknown vector encoding=" + fieldInfo.getVectorEncoding());
+ }
} else if (fieldInfo.getDocValuesType()
!= DocValuesType.NONE) { // the field indexes doc values
switch (fieldInfo.getDocValuesType()) {
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java
index 0a36be780ae..87cbfde97e4 100644
--- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java
@@ -54,7 +54,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
* @param k the number of documents to find
* @throws IllegalArgumentException if <code>k</code> is less than 1
*/
- public KnnByteVectorQuery(String field, byte[] target, int k) {
+ public KnnByteVectorQuery(String field, BytesRef target, int k) {
this(field, target, k, null);
}
@@ -68,9 +68,9 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
* @param filter a filter applied before the vector search
* @throws IllegalArgumentException if <code>k</code> is less than 1
*/
- public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) {
+ public KnnByteVectorQuery(String field, BytesRef target, int k, Query filter) {
super(field, k, filter);
- this.target = new BytesRef(target);
+ this.target = target;
}
@Override
diff --git a/lucene/core/src/java/org/apache/lucene/search/VectorScorer.java b/lucene/core/src/java/org/apache/lucene/search/VectorScorer.java
index eadcdf536b6..a06295a6296 100644
--- a/lucene/core/src/java/org/apache/lucene/search/VectorScorer.java
+++ b/lucene/core/src/java/org/apache/lucene/search/VectorScorer.java
@@ -17,6 +17,7 @@
package org.apache.lucene.search;
import java.io.IOException;
+import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
@@ -29,7 +30,6 @@ import org.apache.lucene.util.BytesRef;
* search over the vectors.
*/
abstract class VectorScorer {
- protected final VectorValues values;
protected final VectorSimilarityFunction similarity;
/**
@@ -48,53 +48,72 @@ abstract class VectorScorer {
static ByteVectorScorer create(LeafReaderContext context, FieldInfo fi, BytesRef query)
throws IOException {
- VectorValues values = context.reader().getVectorValues(fi.name);
+ ByteVectorValues values = context.reader().getByteVectorValues(fi.name);
VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction();
return new ByteVectorScorer(values, query, similarity);
}
- VectorScorer(VectorValues values, VectorSimilarityFunction similarity) {
- this.values = values;
+ VectorScorer(VectorSimilarityFunction similarity) {
this.similarity = similarity;
}
- /**
- * Advance the instance to the given document ID and return true if there is a value for that
- * document.
- */
- public boolean advanceExact(int doc) throws IOException {
- int vectorDoc = values.docID();
- if (vectorDoc < doc) {
- vectorDoc = values.advance(doc);
- }
- return vectorDoc == doc;
- }
-
/** Compute the similarity score for the current document. */
abstract float score() throws IOException;
+ abstract boolean advanceExact(int doc) throws IOException;
+
private static class ByteVectorScorer extends VectorScorer {
private final BytesRef query;
+ private final ByteVectorValues values;
protected ByteVectorScorer(
- VectorValues values, BytesRef query, VectorSimilarityFunction similarity) {
- super(values, similarity);
+ ByteVectorValues values, BytesRef query, VectorSimilarityFunction similarity) {
+ super(similarity);
+ this.values = values;
this.query = query;
}
+ /**
+ * Advance the instance to the given document ID and return true if there is a value for that
+ * document. The underlying iterator only moves forward: if it is already past {@code doc},
+ * this returns false without repositioning.
+ */
+ @Override
+ public boolean advanceExact(int doc) throws IOException {
+ int vectorDoc = values.docID();
+ // only advance when behind; never rewind the forward-only iterator
+ if (vectorDoc < doc) {
+ vectorDoc = values.advance(doc);
+ }
+ return vectorDoc == doc;
+ }
+
@Override
public float score() throws IOException {
- return similarity.compare(query, values.binaryValue());
+ return similarity.compare(query, values.vectorValue());
}
}
private static class FloatVectorScorer extends VectorScorer {
private final float[] query;
+ private final VectorValues values;
protected FloatVectorScorer(
VectorValues values, float[] query, VectorSimilarityFunction similarity) {
- super(values, similarity);
+ super(similarity);
this.query = query;
+ this.values = values;
+ }
+
+ /**
+ * Advance the instance to the given document ID and return true if there is a value for that
+ * document.
+ */
+ @Override
+ public boolean advanceExact(int doc) throws IOException {
+ int vectorDoc = values.docID();
+ if (vectorDoc < doc) {
+ vectorDoc = values.advance(doc);
+ }
+ return vectorDoc == doc;
}
@Override
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index d69c7f6b31b..bcbfc1c2fe2 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -53,7 +53,7 @@ public final class HnswGraphBuilder<T> {
private final VectorSimilarityFunction similarityFunction;
private final VectorEncoding vectorEncoding;
- private final RandomAccessVectorValues vectors;
+ private final RandomAccessVectorValues<T> vectors;
private final SplittableRandom random;
private final HnswGraphSearcher<T> graphSearcher;
@@ -63,10 +63,10 @@ public final class HnswGraphBuilder<T> {
// we need two sources of vectors in order to perform diversity check comparisons without
// colliding
- private final RandomAccessVectorValues vectorsCopy;
+ private final RandomAccessVectorValues<T> vectorsCopy;
- public static HnswGraphBuilder<?> create(
- RandomAccessVectorValues vectors,
+ public static <T> HnswGraphBuilder<T> create(
+ RandomAccessVectorValues<T> vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
int M,
@@ -89,7 +89,7 @@ public final class HnswGraphBuilder<T> {
* to ensure repeatable construction.
*/
private HnswGraphBuilder(
- RandomAccessVectorValues vectors,
+ RandomAccessVectorValues<T> vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
int M,
@@ -131,7 +131,7 @@ public final class HnswGraphBuilder<T> {
* @param vectorsToAdd the vectors for which to build a nearest neighbors graph. Must be an
* independent accessor for the vectors
*/
- public OnHeapHnswGraph build(RandomAccessVectorValues vectorsToAdd) throws IOException {
+ public OnHeapHnswGraph build(RandomAccessVectorValues<T> vectorsToAdd) throws IOException {
if (vectorsToAdd == this.vectors) {
throw new IllegalArgumentException(
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
@@ -143,7 +143,7 @@ public final class HnswGraphBuilder<T> {
return hnsw;
}
- private void addVectors(RandomAccessVectorValues vectorsToAdd) throws IOException {
+ private void addVectors(RandomAccessVectorValues<T> vectorsToAdd) throws IOException {
long start = System.nanoTime(), t = start;
// start at node 1! node 0 is added implicitly, in the constructor
for (int node = 1; node < vectorsToAdd.size(); node++) {
@@ -189,19 +189,8 @@ public final class HnswGraphBuilder<T> {
}
}
- public void addGraphNode(int node, RandomAccessVectorValues values) throws IOException {
- addGraphNode(node, getValue(node, values));
- }
-
- @SuppressWarnings("unchecked")
- private T getValue(int node, RandomAccessVectorValues values) throws IOException {
- switch (vectorEncoding) {
- case BYTE:
- return (T) values.binaryValue(node);
- default:
- case FLOAT32:
- return (T) values.vectorValue(node);
- }
+ public void addGraphNode(int node, RandomAccessVectorValues<T> values) throws IOException {
+ addGraphNode(node, values.vectorValue(node));
}
private long printGraphBuildStatus(int node, long start, long t) {
@@ -285,10 +274,10 @@ public final class HnswGraphBuilder<T> {
throws IOException {
switch (vectorEncoding) {
case BYTE:
- return isDiverse(vectors.binaryValue(candidate), neighbors, score);
+ return isDiverse((BytesRef) vectors.vectorValue(candidate), neighbors, score);
default:
case FLOAT32:
- return isDiverse(vectors.vectorValue(candidate), neighbors, score);
+ return isDiverse((float[]) vectors.vectorValue(candidate), neighbors, score);
}
}
@@ -296,7 +285,8 @@ public final class HnswGraphBuilder<T> {
throws IOException {
for (int i = 0; i < neighbors.size(); i++) {
float neighborSimilarity =
- similarityFunction.compare(candidate, vectorsCopy.vectorValue(neighbors.node[i]));
+ similarityFunction.compare(
+ candidate, (float[]) vectorsCopy.vectorValue(neighbors.node[i]));
if (neighborSimilarity >= score) {
return false;
}
@@ -308,7 +298,8 @@ public final class HnswGraphBuilder<T> {
throws IOException {
for (int i = 0; i < neighbors.size(); i++) {
float neighborSimilarity =
- similarityFunction.compare(candidate, vectorsCopy.binaryValue(neighbors.node[i]));
+ similarityFunction.compare(
+ candidate, (BytesRef) vectorsCopy.vectorValue(neighbors.node[i]));
if (neighborSimilarity >= score) {
return false;
}
@@ -334,10 +325,12 @@ public final class HnswGraphBuilder<T> {
int candidateNode = neighbors.node[candidateIndex];
switch (vectorEncoding) {
case BYTE:
- return isWorstNonDiverse(candidateIndex, vectors.binaryValue(candidateNode), neighbors);
+ return isWorstNonDiverse(
+ candidateIndex, (BytesRef) vectors.vectorValue(candidateNode), neighbors);
default:
case FLOAT32:
- return isWorstNonDiverse(candidateIndex, vectors.vectorValue(candidateNode), neighbors);
+ return isWorstNonDiverse(
+ candidateIndex, (float[]) vectors.vectorValue(candidateNode), neighbors);
}
}
@@ -346,7 +339,8 @@ public final class HnswGraphBuilder<T> {
float minAcceptedSimilarity = neighbors.score[candidateIndex];
for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
- similarityFunction.compare(candidateVector, vectorsCopy.vectorValue(neighbors.node[i]));
+ similarityFunction.compare(
+ candidateVector, (float[]) vectorsCopy.vectorValue(neighbors.node[i]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return true;
@@ -360,7 +354,8 @@ public final class HnswGraphBuilder<T> {
float minAcceptedSimilarity = neighbors.score[candidateIndex];
for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
- similarityFunction.compare(candidateVector, vectorsCopy.binaryValue(neighbors.node[i]));
+ similarityFunction.compare(
+ candidateVector, (BytesRef) vectorsCopy.vectorValue(neighbors.node[i]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return true;
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
index a2650f6e5d9..e058b41f67a 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
@@ -81,7 +81,7 @@ public class HnswGraphSearcher<T> {
public static NeighborQueue search(
float[] query,
int topK,
- RandomAccessVectorValues vectors,
+ RandomAccessVectorValues<float[]> vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
HnswGraph graph,
@@ -137,7 +137,7 @@ public class HnswGraphSearcher<T> {
public static NeighborQueue search(
BytesRef query,
int topK,
- RandomAccessVectorValues vectors,
+ RandomAccessVectorValues<BytesRef> vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
HnswGraph graph,
@@ -198,7 +198,7 @@ public class HnswGraphSearcher<T> {
int topK,
int level,
final int[] eps,
- RandomAccessVectorValues vectors,
+ RandomAccessVectorValues<T> vectors,
HnswGraph graph)
throws IOException {
return searchLevel(query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE);
@@ -209,7 +209,7 @@ public class HnswGraphSearcher<T> {
int topK,
int level,
final int[] eps,
- RandomAccessVectorValues vectors,
+ RandomAccessVectorValues<T> vectors,
HnswGraph graph,
Bits acceptOrds,
int visitedLimit)
@@ -279,11 +279,11 @@ public class HnswGraphSearcher<T> {
return results;
}
- private float compare(T query, RandomAccessVectorValues vectors, int ord) throws IOException {
+ private float compare(T query, RandomAccessVectorValues<T> vectors, int ord) throws IOException {
if (vectorEncoding == VectorEncoding.BYTE) {
- return similarityFunction.compare((BytesRef) query, vectors.binaryValue(ord));
+ return similarityFunction.compare((BytesRef) query, (BytesRef) vectors.vectorValue(ord));
} else {
- return similarityFunction.compare((float[]) query, vectors.vectorValue(ord));
+ return similarityFunction.compare((float[]) query, (float[]) vectors.vectorValue(ord));
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java
index 4c220518a2e..956749e678e 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java
@@ -18,7 +18,6 @@
package org.apache.lucene.util.hnsw;
import java.io.IOException;
-import org.apache.lucene.util.BytesRef;
/**
* Provides random access to vectors by dense ordinal. This interface is used by HNSW-based
@@ -26,7 +25,7 @@ import org.apache.lucene.util.BytesRef;
*
* @lucene.experimental
*/
-public interface RandomAccessVectorValues {
+public interface RandomAccessVectorValues<T> {
/** Return the number of vector values */
int size();
@@ -35,26 +34,16 @@ public interface RandomAccessVectorValues {
int dimension();
/**
- * Return the vector value indexed at the given ordinal. The provided floating point array may be
- * shared and overwritten by subsequent calls to this method and {@link #binaryValue(int)}.
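+ * Return the vector value indexed at the given ordinal. The returned value may be shared and
+ * overwritten by subsequent calls to this method; use {@link #copy()} when two values must be
+ * held at once.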
*
* @param targetOrd a valid ordinal, ≥ 0 and < {@link #size()}.
*/
- float[] vectorValue(int targetOrd) throws IOException;
-
- /**
- * Return the vector indexed at the given ordinal value as an array of bytes in a BytesRef; these
- * are the bytes corresponding to the float array. The provided bytes may be shared and
- * overwritten by subsequent calls to this method and {@link #vectorValue(int)}.
- *
- * @param targetOrd a valid ordinal, ≥ 0 and < {@link #size()}.
- */
- BytesRef binaryValue(int targetOrd) throws IOException;
+ T vectorValue(int targetOrd) throws IOException;
/**
* Creates a new copy of this {@link RandomAccessVectorValues}. This is helpful when you need to
* access different values at once, to avoid overwriting the underlying float vector returned by
* {@link RandomAccessVectorValues#vectorValue}.
*/
- RandomAccessVectorValues copy() throws IOException;
+ RandomAccessVectorValues<T> copy() throws IOException;
}
diff --git a/lucene/core/src/test/org/apache/lucene/document/TestField.java b/lucene/core/src/test/org/apache/lucene/document/TestField.java
index da0b65bc897..534f06ff7a6 100644
--- a/lucene/core/src/test/org/apache/lucene/document/TestField.java
+++ b/lucene/core/src/test/org/apache/lucene/document/TestField.java
@@ -21,6 +21,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.StringReader;
import java.nio.charset.StandardCharsets;
import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
@@ -611,25 +612,22 @@ public class TestField extends LuceneTestCase {
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
BytesRef br = newBytesRef(new byte[5]);
- Field field = new KnnVectorField("binary", br, VectorSimilarityFunction.EUCLIDEAN);
+ Field field = new KnnByteVectorField("binary", br, VectorSimilarityFunction.EUCLIDEAN);
expectThrows(
IllegalArgumentException.class,
() -> new KnnVectorField("bogus", new float[] {1}, (FieldType) field.fieldType()));
float[] vector = new float[] {1, 2};
Field field2 = new KnnVectorField("float", vector);
- expectThrows(
- IllegalArgumentException.class,
- () -> new KnnVectorField("bogus", br, (FieldType) field2.fieldType()));
assertEquals(br, field.binaryValue());
doc.add(field);
doc.add(field2);
w.addDocument(doc);
try (IndexReader r = DirectoryReader.open(w)) {
- VectorValues binary = r.leaves().get(0).reader().getVectorValues("binary");
+ ByteVectorValues binary = r.leaves().get(0).reader().getByteVectorValues("binary");
assertEquals(1, binary.size());
assertNotEquals(NO_MORE_DOCS, binary.nextDoc());
- assertEquals(br, binary.binaryValue());
assertNotNull(binary.vectorValue());
+ assertEquals(br, binary.vectorValue());
assertEquals(NO_MORE_DOCS, binary.nextDoc());
VectorValues floatValues = r.leaves().get(0).reader().getVectorValues("float");
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java b/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java
index 279456a442e..57b2435f32d 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java
@@ -117,6 +117,11 @@ public class TestSegmentToThreadMapping extends LuceneTestCase {
return null;
}
+ @Override
+ public ByteVectorValues getByteVectorValues(String field) {
+ // stub: this test reader exposes no byte vectors
+ return null;
+ }
+
@Override
public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java
index f5037af9177..f3cbfb01e16 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java
@@ -17,7 +17,7 @@
package org.apache.lucene.search;
import org.apache.lucene.document.Field;
-import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
@@ -27,12 +27,12 @@ import org.apache.lucene.util.TestVectorUtil;
public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
@Override
AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) {
- return new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter);
+ return new KnnByteVectorQuery(field, new BytesRef(floatToBytes(query)), k, queryFilter);
}
@Override
AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) {
- return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query);
+ return new ThrowingKnnVectorQuery(field, new BytesRef(floatToBytes(vec)), k, query);
}
@Override
@@ -49,12 +49,12 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
@Override
Field getKnnVectorField(
String name, float[] vector, VectorSimilarityFunction similarityFunction) {
- return new KnnVectorField(name, new BytesRef(floatToBytes(vector)), similarityFunction);
+ return new KnnByteVectorField(name, new BytesRef(floatToBytes(vector)), similarityFunction);
}
@Override
Field getKnnVectorField(String name, float[] vector) {
- return new KnnVectorField(
+ return new KnnByteVectorField(
name, new BytesRef(floatToBytes(vector)), VectorSimilarityFunction.EUCLIDEAN);
}
@@ -80,7 +80,7 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery {
- public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) {
+ public ThrowingKnnVectorQuery(String field, BytesRef target, int k, Query filter) {
super(field, target, k, filter);
}
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java b/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java
index 2c72baa73d3..99ea200a16d 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestVectorScorer.java
@@ -22,6 +22,7 @@ import com.carrotsearch.randomizedtesting.generators.RandomPicks;
import java.io.IOException;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
+import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
@@ -79,7 +80,7 @@ public class TestVectorScorer extends LuceneTestCase {
for (int j = 0; j < v.length; j++) {
v.bytes[j] = (byte) contents[i][j];
}
- doc.add(new KnnVectorField(field, v, EUCLIDEAN));
+ doc.add(new KnnByteVectorField(field, v, EUCLIDEAN));
} else {
doc.add(new KnnVectorField(field, contents[i]));
}
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/AbstractMockVectorValues.java
similarity index 52%
copy from lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java
copy to lucene/core/src/test/org/apache/lucene/util/hnsw/AbstractMockVectorValues.java
index 299d27857da..166ee00dcf7 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/AbstractMockVectorValues.java
@@ -17,42 +17,29 @@
package org.apache.lucene.util.hnsw;
-import org.apache.lucene.index.VectorValues;
-import org.apache.lucene.tests.util.LuceneTestCase;
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+import java.io.IOException;
import org.apache.lucene.util.BytesRef;
-class MockVectorValues extends VectorValues implements RandomAccessVectorValues {
- private final float[] scratch;
+abstract class AbstractMockVectorValues<T> implements RandomAccessVectorValues<T> {
protected final int dimension;
- protected final float[][] denseValues;
- protected final float[][] values;
- private final int numVectors;
- private final BytesRef binaryValue;
+ protected final T[] denseValues;
+ protected final T[] values;
+ protected final int numVectors;
+ protected final BytesRef binaryValue;
- private int pos = -1;
+ protected int pos = -1;
- MockVectorValues(float[][] values) {
- this.dimension = values[0].length;
+ AbstractMockVectorValues(T[] values, int dimension, T[] denseValues, int numVectors) {
+ this.dimension = dimension;
this.values = values;
- int maxDoc = values.length;
- denseValues = new float[maxDoc][];
- int count = 0;
- for (int i = 0; i < maxDoc; i++) {
- if (values[i] != null) {
- denseValues[count++] = values[i];
- }
- }
- numVectors = count;
- scratch = new float[dimension];
+ this.denseValues = denseValues;
// used by tests that build a graph from bytes rather than floats
binaryValue = new BytesRef(dimension);
binaryValue.length = dimension;
- }
-
- @Override
- public MockVectorValues copy() {
- return new MockVectorValues(values);
+ this.numVectors = numVectors;
}
@Override
@@ -66,32 +53,14 @@ class MockVectorValues extends VectorValues implements RandomAccessVectorValues
}
@Override
- public float[] vectorValue() {
- if (LuceneTestCase.random().nextBoolean()) {
- return values[pos];
- } else {
- // Sometimes use the same scratch array repeatedly, mimicing what the codec will do.
- // This should help us catch cases of aliasing where the same VectorValues source is used
- // twice in a
- // single computation.
- System.arraycopy(values[pos], 0, scratch, 0, dimension);
- return scratch;
- }
- }
-
- @Override
- public float[] vectorValue(int targetOrd) {
+ public T vectorValue(int targetOrd) {
return denseValues[targetOrd];
}
@Override
- public BytesRef binaryValue(int targetOrd) {
- float[] value = vectorValue(targetOrd);
- for (int i = 0; i < value.length; i++) {
- binaryValue.bytes[i] = (byte) value[i];
- }
- return binaryValue;
- }
+ public abstract AbstractMockVectorValues<T> copy();
+
+ public abstract T vectorValue() throws IOException;
private boolean seek(int target) {
if (target >= 0 && target < values.length && values[target] != null) {
@@ -102,17 +71,14 @@ class MockVectorValues extends VectorValues implements RandomAccessVectorValues
}
}
- @Override
public int docID() {
return pos;
}
- @Override
public int nextDoc() {
return advance(pos + 1);
}
- @Override
public int advance(int target) {
while (++pos < values.length) {
if (seek(pos)) {
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java
similarity index 73%
rename from lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
rename to lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java
index e3cfa2d462b..93306046b69 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java
@@ -20,7 +20,6 @@ package org.apache.lucene.util.hnsw;
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.tests.util.RamUsageTester.ramUsed;
-import static org.apache.lucene.util.VectorUtil.toBytesRef;
import com.carrotsearch.randomizedtesting.RandomizedTest;
import java.io.IOException;
@@ -36,21 +35,23 @@ import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsReader;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
-import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.document.Field;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.StoredField;
+import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.IndexSearcher;
-import org.apache.lucene.search.KnnVectorQuery;
+import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
@@ -65,19 +66,30 @@ import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
-import org.junit.Before;
/** Tests HNSW KNN graphs */
-public class TestHnswGraph extends LuceneTestCase {
+abstract class HnswGraphTestCase<T> extends LuceneTestCase {
VectorSimilarityFunction similarityFunction;
- VectorEncoding vectorEncoding;
- @Before
- public void setup() {
- similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values());
- vectorEncoding = RandomizedTest.randomFrom(VectorEncoding.values());
- }
+ abstract VectorEncoding getVectorEncoding();
+
+ abstract Query knnQuery(String field, T vector, int k);
+
+ abstract T randomVector(int dim);
+
+ abstract AbstractMockVectorValues<T> vectorValues(int size, int dimension);
+
+ abstract AbstractMockVectorValues<T> vectorValues(float[][] values);
+
+ abstract AbstractMockVectorValues<T> vectorValues(LeafReader reader, String fieldName)
+ throws IOException;
+
+ abstract Field knnVectorField(String name, T vector, VectorSimilarityFunction similarityFunction);
+
+ abstract RandomAccessVectorValues<T> circularVectorValues(int nDoc);
+
+ abstract T getTargetVector();
// test writing out and reading in a graph gives the expected graph
public void testReadWrite() throws IOException {
@@ -86,10 +98,11 @@ public class TestHnswGraph extends LuceneTestCase {
int M = random().nextInt(4) + 2;
int beamWidth = random().nextInt(10) + 5;
long seed = random().nextLong();
- RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, vectorEncoding, random());
- RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy();
- HnswGraphBuilder<?> builder =
- HnswGraphBuilder.create(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed);
+ AbstractMockVectorValues<T> vectors = vectorValues(nDoc, dim);
+ AbstractMockVectorValues<T> v2 = vectors.copy(), v3 = vectors.copy();
+ HnswGraphBuilder<T> builder =
+ HnswGraphBuilder.create(
+ vectors, getVectorEncoding(), similarityFunction, M, beamWidth, seed);
HnswGraph hnsw = builder.build(vectors.copy());
// Recreate the graph while indexing with the same random seed and write it out
@@ -115,7 +128,7 @@ public class TestHnswGraph extends LuceneTestCase {
indexedDoc++;
}
Document doc = new Document();
- doc.add(new KnnVectorField("field", v2.vectorValue(), similarityFunction));
+ doc.add(knnVectorField("field", v2.vectorValue(), similarityFunction));
doc.add(new StoredField("id", v2.docID()));
iw.addDocument(doc);
nVec++;
@@ -124,7 +137,7 @@ public class TestHnswGraph extends LuceneTestCase {
}
try (IndexReader reader = DirectoryReader.open(dir)) {
for (LeafReaderContext ctx : reader.leaves()) {
- VectorValues values = ctx.reader().getVectorValues("field");
+ AbstractMockVectorValues<T> values = vectorValues(ctx.reader(), "field");
assertEquals(dim, values.dimension());
assertEquals(nVec, values.size());
assertEquals(indexedDoc, ctx.reader().maxDoc());
@@ -142,15 +155,11 @@ public class TestHnswGraph extends LuceneTestCase {
}
}
- private VectorEncoding randomVectorEncoding() {
- return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
- }
-
// test that sorted index returns the same search results are unsorted
public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException {
int dim = random().nextInt(10) + 3;
int nDoc = random().nextInt(200) + 100;
- RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, random());
+ AbstractMockVectorValues<T> vectors = vectorValues(nDoc, dim);
int M = random().nextInt(10) + 5;
int beamWidth = random().nextInt(10) + 5;
@@ -190,7 +199,7 @@ public class TestHnswGraph extends LuceneTestCase {
indexedDoc++;
}
Document doc = new Document();
- doc.add(new KnnVectorField("vector", vectors.vectorValue(), similarityFunction));
+ doc.add(knnVectorField("vector", vectors.vectorValue(), similarityFunction));
doc.add(new StoredField("id", vectors.docID()));
doc.add(new NumericDocValuesField("sortkey", random().nextLong()));
iw.addDocument(doc);
@@ -206,7 +215,7 @@ public class TestHnswGraph extends LuceneTestCase {
for (int i = 0; i < 10; i++) {
// ask to explore a lot of candidates to ensure the same returned hits,
// as graphs of 2 indices are organized differently
- KnnVectorQuery query = new KnnVectorQuery("vector", randomVector(random(), dim), 50);
+ Query query = knnQuery("vector", randomVector(dim), 50);
List<String> ids1 = new ArrayList<>();
List<Integer> docs1 = new ArrayList<>();
List<String> ids2 = new ArrayList<>();
@@ -241,7 +250,7 @@ public class TestHnswGraph extends LuceneTestCase {
}
}
- private void assertGraphEqual(HnswGraph g, HnswGraph h) throws IOException {
+ void assertGraphEqual(HnswGraph g, HnswGraph h) throws IOException {
assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels());
assertEquals("the number of nodes in the graphs are different!", g.size(), h.size());
@@ -271,25 +280,25 @@ public class TestHnswGraph extends LuceneTestCase {
// Make sure we actually approximately find the closest k elements. Mostly this is about
// ensuring that we have all the distance functions, comparators, priority queues and so on
// oriented in the right directions
+ @SuppressWarnings("unchecked")
public void testAknnDiverse() throws IOException {
int nDoc = 100;
- vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
- CircularVectorValues vectors = new CircularVectorValues(nDoc);
- HnswGraphBuilder<?> builder =
+ RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
+ HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(
- vectors, vectorEncoding, similarityFunction, 10, 100, random().nextInt());
+ vectors, getVectorEncoding(), similarityFunction, 10, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
// run some searches
final NeighborQueue nn;
- switch (vectorEncoding) {
+ switch (getVectorEncoding()) {
case FLOAT32:
nn =
HnswGraphSearcher.search(
- getTargetVector(),
+ (float[]) getTargetVector(),
10,
- vectors.copy(),
- vectorEncoding,
+ (RandomAccessVectorValues<float[]>) vectors.copy(),
+ getVectorEncoding(),
similarityFunction,
hnsw,
null,
@@ -298,19 +307,18 @@ public class TestHnswGraph extends LuceneTestCase {
case BYTE:
nn =
HnswGraphSearcher.search(
- getTargetByteVector(),
+ (BytesRef) getTargetVector(),
10,
- vectors.copy(),
- vectorEncoding,
+ (RandomAccessVectorValues<BytesRef>) vectors.copy(),
+ getVectorEncoding(),
similarityFunction,
hnsw,
null,
Integer.MAX_VALUE);
break;
default:
- throw new IllegalArgumentException("unexpected vector encoding: " + vectorEncoding);
+ throw new IllegalArgumentException("unexpected vector encoding: " + getVectorEncoding());
}
-
int[] nodes = nn.nodes();
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
int sum = 0;
@@ -331,26 +339,26 @@ public class TestHnswGraph extends LuceneTestCase {
}
}
+ @SuppressWarnings("unchecked")
public void testSearchWithAcceptOrds() throws IOException {
int nDoc = 100;
- CircularVectorValues vectors = new CircularVectorValues(nDoc);
+ RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
- vectorEncoding = randomVectorEncoding();
- HnswGraphBuilder<?> builder =
+ HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(
- vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
+ vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
// the first 10 docs must not be deleted to ensure the expected recall
- Bits acceptOrds = createRandomAcceptOrds(10, vectors.size);
+ Bits acceptOrds = createRandomAcceptOrds(10, nDoc);
final NeighborQueue nn;
- switch (vectorEncoding) {
+ switch (getVectorEncoding()) {
case FLOAT32:
nn =
HnswGraphSearcher.search(
- getTargetVector(),
+ (float[]) getTargetVector(),
10,
- vectors.copy(),
- vectorEncoding,
+ (RandomAccessVectorValues<float[]>) vectors.copy(),
+ getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
@@ -359,17 +367,17 @@ public class TestHnswGraph extends LuceneTestCase {
case BYTE:
nn =
HnswGraphSearcher.search(
- getTargetByteVector(),
+ (BytesRef) getTargetVector(),
10,
- vectors.copy(),
- vectorEncoding,
+ (RandomAccessVectorValues<BytesRef>) vectors.copy(),
+ getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
break;
default:
- throw new IllegalArgumentException("unexpected vector encoding: " + vectorEncoding);
+ throw new IllegalArgumentException("unexpected vector encoding: " + getVectorEncoding());
}
int[] nodes = nn.nodes();
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
@@ -383,32 +391,32 @@ public class TestHnswGraph extends LuceneTestCase {
assertTrue("sum(result docs)=" + sum, sum < 75);
}
+ @SuppressWarnings("unchecked")
public void testSearchWithSelectiveAcceptOrds() throws IOException {
int nDoc = 100;
- CircularVectorValues vectors = new CircularVectorValues(nDoc);
- vectorEncoding = randomVectorEncoding();
+ RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
- HnswGraphBuilder<?> builder =
+ HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(
- vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
+ vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
// Only mark a few vectors as accepted
- BitSet acceptOrds = new FixedBitSet(vectors.size);
- for (int i = 0; i < vectors.size; i += 15 + random().nextInt(5)) {
+ BitSet acceptOrds = new FixedBitSet(nDoc);
+ for (int i = 0; i < nDoc; i += 15 + random().nextInt(5)) {
acceptOrds.set(i);
}
// Check the search finds all accepted vectors
int numAccepted = acceptOrds.cardinality();
final NeighborQueue nn;
- switch (vectorEncoding) {
+ switch (getVectorEncoding()) {
case FLOAT32:
nn =
HnswGraphSearcher.search(
- getTargetVector(),
+ (float[]) getTargetVector(),
numAccepted,
- vectors.copy(),
- vectorEncoding,
+ (RandomAccessVectorValues<float[]>) vectors.copy(),
+ getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
@@ -417,17 +425,17 @@ public class TestHnswGraph extends LuceneTestCase {
case BYTE:
nn =
HnswGraphSearcher.search(
- getTargetByteVector(),
+ (BytesRef) getTargetVector(),
numAccepted,
- vectors.copy(),
- vectorEncoding,
+ (RandomAccessVectorValues<BytesRef>) vectors.copy(),
+ getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
break;
default:
- throw new IllegalArgumentException("unexpected vector encoding: " + vectorEncoding);
+ throw new IllegalArgumentException("unexpected vector encoding: " + getVectorEncoding());
}
int[] nodes = nn.nodes();
assertEquals(numAccepted, nodes.length);
@@ -436,92 +444,47 @@ public class TestHnswGraph extends LuceneTestCase {
}
}
- private float[] getTargetVector() {
- return new float[] {1, 0};
- }
-
- private BytesRef getTargetByteVector() {
- return new BytesRef(new byte[] {1, 0});
- }
-
- public void testSearchWithSkewedAcceptOrds() throws IOException {
- int nDoc = 1000;
- similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
- CircularVectorValues vectors = new CircularVectorValues(nDoc);
- HnswGraphBuilder<?> builder =
- HnswGraphBuilder.create(
- vectors, VectorEncoding.FLOAT32, similarityFunction, 16, 100, random().nextInt());
- OnHeapHnswGraph hnsw = builder.build(vectors.copy());
-
- // Skip over half of the documents that are closest to the query vector
- FixedBitSet acceptOrds = new FixedBitSet(nDoc);
- for (int i = 500; i < nDoc; i++) {
- acceptOrds.set(i);
- }
- NeighborQueue nn =
- HnswGraphSearcher.search(
- getTargetVector(),
- 10,
- vectors.copy(),
- VectorEncoding.FLOAT32,
- similarityFunction,
- hnsw,
- acceptOrds,
- Integer.MAX_VALUE);
- int[] nodes = nn.nodes();
- assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
- int sum = 0;
- for (int node : nodes) {
- assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
- sum += node;
- }
- // We still expect to get reasonable recall. The lowest non-skipped docIds
- // are closest to the query vector: sum(500,509) = 5045
- assertTrue("sum(result docs)=" + sum, sum < 5100);
- }
-
+ @SuppressWarnings("unchecked")
public void testVisitedLimit() throws IOException {
int nDoc = 500;
- vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
- CircularVectorValues vectors = new CircularVectorValues(nDoc);
- HnswGraphBuilder<?> builder =
+ RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
+ HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(
- vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
+ vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
int topK = 50;
int visitedLimit = topK + random().nextInt(5);
final NeighborQueue nn;
- switch (vectorEncoding) {
+ switch (getVectorEncoding()) {
case FLOAT32:
nn =
HnswGraphSearcher.search(
- getTargetVector(),
+ (float[]) getTargetVector(),
topK,
- vectors.copy(),
- vectorEncoding,
+ (RandomAccessVectorValues<float[]>) vectors.copy(),
+ getVectorEncoding(),
similarityFunction,
hnsw,
- createRandomAcceptOrds(0, vectors.size),
+ createRandomAcceptOrds(0, nDoc),
visitedLimit);
break;
case BYTE:
nn =
HnswGraphSearcher.search(
- getTargetByteVector(),
+ (BytesRef) getTargetVector(),
topK,
- vectors.copy(),
- vectorEncoding,
+ (RandomAccessVectorValues<BytesRef>) vectors.copy(),
+ getVectorEncoding(),
similarityFunction,
hnsw,
- createRandomAcceptOrds(0, vectors.size),
+ createRandomAcceptOrds(0, nDoc),
visitedLimit);
break;
default:
- throw new IllegalArgumentException("unexpected vector encoding: " + vectorEncoding);
+ throw new IllegalArgumentException("unexpected vector encoding: " + getVectorEncoding());
}
-
assertTrue(nn.incomplete());
// The visited count shouldn't exceed the limit
assertTrue(nn.visitedCount() <= visitedLimit);
@@ -535,8 +498,8 @@ public class TestHnswGraph extends LuceneTestCase {
IllegalArgumentException.class,
() ->
HnswGraphBuilder.create(
- new RandomVectorValues(1, 1, random()),
- VectorEncoding.FLOAT32,
+ vectorValues(1, 1),
+ getVectorEncoding(),
VectorSimilarityFunction.EUCLIDEAN,
0,
10,
@@ -546,8 +509,8 @@ public class TestHnswGraph extends LuceneTestCase {
IllegalArgumentException.class,
() ->
HnswGraphBuilder.create(
- new RandomVectorValues(1, 1, random()),
- VectorEncoding.FLOAT32,
+ vectorValues(1, 1),
+ getVectorEncoding(),
VectorSimilarityFunction.EUCLIDEAN,
10,
0,
@@ -561,13 +524,11 @@ public class TestHnswGraph extends LuceneTestCase {
VectorSimilarityFunction similarityFunction =
RandomizedTest.randomFrom(VectorSimilarityFunction.values());
- VectorEncoding vectorEncoding = RandomizedTest.randomFrom(VectorEncoding.values());
- TestHnswGraph.RandomVectorValues vectors =
- new TestHnswGraph.RandomVectorValues(size, dim, vectorEncoding, random());
+ RandomAccessVectorValues<T> vectors = vectorValues(size, dim);
- HnswGraphBuilder<?> builder =
+ HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(
- vectors, vectorEncoding, similarityFunction, M, M * 2, random().nextLong());
+ vectors, getVectorEncoding(), similarityFunction, M, M * 2, random().nextLong());
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
long estimated = RamUsageEstimator.sizeOfObject(hnsw);
long actual = ramUsed(hnsw);
@@ -577,7 +538,6 @@ public class TestHnswGraph extends LuceneTestCase {
@SuppressWarnings("unchecked")
public void testDiversity() throws IOException {
- vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
// Some carefully checked test cases with simple 2d vectors on the unit circle:
float[][] values = {
@@ -589,21 +549,14 @@ public class TestHnswGraph extends LuceneTestCase {
unitVector2d(0.77),
unitVector2d(0.6)
};
- if (vectorEncoding == VectorEncoding.BYTE) {
- for (float[] v : values) {
- for (int i = 0; i < v.length; i++) {
- v[i] *= 127;
- }
- }
- }
- MockVectorValues vectors = new MockVectorValues(values);
+ AbstractMockVectorValues<T> vectors = vectorValues(values);
// First add nodes until everybody gets a full neighbor list
- HnswGraphBuilder<?> builder =
+ HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(
- vectors, vectorEncoding, similarityFunction, 2, 10, random().nextInt());
+ vectors, getVectorEncoding(), similarityFunction, 2, 10, random().nextInt());
// node 0 is added by the builder constructor
// builder.addGraphNode(vectors.vectorValue(0));
- RandomAccessVectorValues vectorsCopy = vectors.copy();
+ RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
builder.addGraphNode(1, vectorsCopy);
builder.addGraphNode(2, vectorsCopy);
// now every node has tried to attach every other node as a neighbor, but
@@ -640,7 +593,6 @@ public class TestHnswGraph extends LuceneTestCase {
}
public void testDiversityFallback() throws IOException {
- vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
// Some test cases can't be exercised in two dimensions;
// in particular if a new neighbor displaces an existing neighbor
@@ -653,14 +605,14 @@ public class TestHnswGraph extends LuceneTestCase {
{10, 0, 0},
{0, 4, 0}
};
- MockVectorValues vectors = new MockVectorValues(values);
+ AbstractMockVectorValues<T> vectors = vectorValues(values);
// First add nodes until everybody gets a full neighbor list
- HnswGraphBuilder<?> builder =
+ HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(
- vectors, vectorEncoding, similarityFunction, 1, 10, random().nextInt());
+ vectors, getVectorEncoding(), similarityFunction, 1, 10, random().nextInt());
// node 0 is added by the builder constructor
// builder.addGraphNode(vectors.vectorValue(0));
- RandomAccessVectorValues vectorsCopy = vectors.copy();
+ RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
builder.addGraphNode(1, vectorsCopy);
builder.addGraphNode(2, vectorsCopy);
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
@@ -678,7 +630,6 @@ public class TestHnswGraph extends LuceneTestCase {
}
public void testDiversity3d() throws IOException {
- vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
// test the case when a neighbor *becomes* non-diverse when a newer better neighbor arrives
float[][] values = {
@@ -687,14 +638,14 @@ public class TestHnswGraph extends LuceneTestCase {
{0, 0, 20},
{0, 9, 0}
};
- MockVectorValues vectors = new MockVectorValues(values);
+ AbstractMockVectorValues<T> vectors = vectorValues(values);
// First add nodes until everybody gets a full neighbor list
- HnswGraphBuilder<?> builder =
+ HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(
- vectors, vectorEncoding, similarityFunction, 1, 10, random().nextInt());
+ vectors, getVectorEncoding(), similarityFunction, 1, 10, random().nextInt());
// node 0 is added by the builder constructor
// builder.addGraphNode(vectors.vectorValue(0));
- RandomAccessVectorValues vectorsCopy = vectors.copy();
+ RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
builder.addGraphNode(1, vectorsCopy);
builder.addGraphNode(2, vectorsCopy);
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
@@ -722,36 +673,30 @@ public class TestHnswGraph extends LuceneTestCase {
actual);
}
+ @SuppressWarnings("unchecked")
public void testRandom() throws IOException {
int size = atLeast(100);
int dim = atLeast(10);
- RandomVectorValues vectors = new RandomVectorValues(size, dim, vectorEncoding, random());
+ AbstractMockVectorValues<T> vectors = vectorValues(size, dim);
int topK = 5;
- HnswGraphBuilder<?> builder =
+ HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(
- vectors, vectorEncoding, similarityFunction, 10, 30, random().nextLong());
+ vectors, getVectorEncoding(), similarityFunction, 10, 30, random().nextLong());
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);
int totalMatches = 0;
for (int i = 0; i < 100; i++) {
- float[] query;
- BytesRef bQuery = null;
- if (vectorEncoding == VectorEncoding.BYTE) {
- query = randomVector8(random(), dim);
- bQuery = toBytesRef(query);
- } else {
- query = randomVector(random(), dim);
- }
final NeighborQueue actual;
- switch (vectorEncoding) {
+ T query = randomVector(dim);
+ switch (getVectorEncoding()) {
case BYTE:
actual =
HnswGraphSearcher.search(
- bQuery,
+ (BytesRef) query,
100,
- vectors,
- vectorEncoding,
+ (RandomAccessVectorValues<BytesRef>) vectors,
+ getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
@@ -760,17 +705,17 @@ public class TestHnswGraph extends LuceneTestCase {
case FLOAT32:
actual =
HnswGraphSearcher.search(
- query,
+ (float[]) query,
100,
- vectors,
- vectorEncoding,
+ (RandomAccessVectorValues<float[]>) vectors,
+ getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
break;
default:
- throw new IllegalArgumentException("unexpected vector encoding: " + vectorEncoding);
+ throw new IllegalArgumentException("unexpected vector encoding: " + getVectorEncoding());
}
while (actual.size() > topK) {
actual.pop();
@@ -778,10 +723,14 @@ public class TestHnswGraph extends LuceneTestCase {
NeighborQueue expected = new NeighborQueue(topK, false);
for (int j = 0; j < size; j++) {
if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
- if (vectorEncoding == VectorEncoding.BYTE) {
- expected.add(j, similarityFunction.compare(bQuery, vectors.binaryValue(j)));
+ if (getVectorEncoding() == VectorEncoding.BYTE) {
+ assert query instanceof BytesRef;
+ expected.add(
+ j, similarityFunction.compare((BytesRef) query, (BytesRef) vectors.vectorValue(j)));
} else {
- expected.add(j, similarityFunction.compare(query, vectors.vectorValue(j)));
+ assert query instanceof float[];
+ expected.add(
+ j, similarityFunction.compare((float[]) query, (float[]) vectors.vectorValue(j)));
}
if (expected.size() > topK) {
expected.pop();
@@ -815,17 +764,16 @@ public class TestHnswGraph extends LuceneTestCase {
}
/** Returns vectors evenly distributed around the upper unit semicircle. */
- static class CircularVectorValues extends VectorValues implements RandomAccessVectorValues {
+ static class CircularVectorValues extends VectorValues
+ implements RandomAccessVectorValues<float[]> {
private final int size;
private final float[] value;
- private final BytesRef binaryValue;
int doc = -1;
CircularVectorValues(int size) {
this.size = size;
value = new float[2];
- binaryValue = new BytesRef(new byte[2]);
}
@Override
@@ -872,14 +820,70 @@ public class TestHnswGraph extends LuceneTestCase {
public float[] vectorValue(int ord) {
return unitVector2d(ord / (double) size, value);
}
+ }
+
+ /** Like CircularVectorValues, but each 2d component is scaled by 127 and cast to a byte. */
+ static class CircularByteVectorValues extends ByteVectorValues
+ implements RandomAccessVectorValues<BytesRef> {
+ private final int size;
+ private final float[] value;
+ private final BytesRef bValue; // reused for every vectorValue(ord) result
+
+ int doc = -1; // -1 before iteration; NO_MORE_DOCS once exhausted
+
+ CircularByteVectorValues(int size) {
+ this.size = size;
+ value = new float[2];
+ bValue = new BytesRef(new byte[2]);
+ }
@Override
- public BytesRef binaryValue(int ord) {
- float[] vectorValue = vectorValue(ord);
- for (int i = 0; i < vectorValue.length; i++) {
- binaryValue.bytes[i] = (byte) (vectorValue[i] * 127);
+ public CircularByteVectorValues copy() {
+ return new CircularByteVectorValues(size);
+ }
+
+ @Override
+ public int dimension() {
+ return 2;
+ }
+
+ @Override
+ public int size() {
+ return size;
+ }
+
+ @Override
+ public BytesRef vectorValue() {
+ return vectorValue(doc);
+ }
+
+ @Override
+ public int docID() {
+ return doc;
+ }
+
+ @Override
+ public int nextDoc() {
+ return advance(doc + 1);
+ }
+
+ @Override
+ public int advance(int target) {
+ if (target >= 0 && target < size) {
+ doc = target;
+ } else {
+ doc = NO_MORE_DOCS;
+ }
+ return doc;
+ }
+
+ @Override
+ public BytesRef vectorValue(int ord) {
+ unitVector2d(ord / (double) size, value); // result is read back from value below
+ for (int i = 0; i < value.length; i++) {
+ bValue.bytes[i] = (byte) (value[i] * 127);
}
- return binaryValue;
+ return bValue; // same instance every call; the next call overwrites its bytes
}
}
@@ -901,7 +905,8 @@ public class TestHnswGraph extends LuceneTestCase {
return neighbors;
}
- private void assertVectorsEqual(VectorValues u, VectorValues v) throws IOException {
+ void assertVectorsEqual(AbstractMockVectorValues<T> u, AbstractMockVectorValues<T> v)
+ throws IOException {
int uDoc, vDoc;
while (true) {
uDoc = u.nextDoc();
@@ -910,49 +915,40 @@ public class TestHnswGraph extends LuceneTestCase {
if (uDoc == NO_MORE_DOCS) {
break;
}
- float delta = vectorEncoding == VectorEncoding.BYTE ? 1 : 1e-4f;
- assertArrayEquals(
- "vectors do not match for doc=" + uDoc, u.vectorValue(), v.vectorValue(), delta);
+ switch (getVectorEncoding()) {
+ case BYTE:
+ assertArrayEquals(
+ "vectors do not match for doc=" + uDoc,
+ ((BytesRef) u.vectorValue()).bytes,
+ ((BytesRef) v.vectorValue()).bytes);
+ break;
+ case FLOAT32:
+ assertArrayEquals(
+ "vectors do not match for doc=" + uDoc,
+ (float[]) u.vectorValue(),
+ (float[]) v.vectorValue(),
+ 1e-4f);
+ break;
+ default:
+ throw new IllegalArgumentException("unknown vector encoding: " + getVectorEncoding());
+ }
}
}
- /** Produces random vectors and caches them for random-access. */
- static class RandomVectorValues extends MockVectorValues {
-
- RandomVectorValues(int size, int dimension, Random random) {
- super(createRandomVectors(size, dimension, null, random));
- }
-
- RandomVectorValues(int size, int dimension, VectorEncoding vectorEncoding, Random random) {
- super(createRandomVectors(size, dimension, vectorEncoding, random));
- }
-
- RandomVectorValues(RandomVectorValues other) {
- super(other.values);
- }
-
- @Override
- public RandomVectorValues copy() {
- return new RandomVectorValues(this);
+ static float[][] createRandomFloatVectors(int size, int dimension, Random random) {
+ float[][] vectors = new float[size][]; // docs skipped below keep a null vector
+ for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) { // step of 1-3
+ vectors[offset] = randomVector(random, dimension);
}
+ return vectors;
+ }
- private static float[][] createRandomVectors(
- int size, int dimension, VectorEncoding vectorEncoding, Random random) {
- float[][] vectors = new float[size][];
- for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
- vectors[offset] = randomVector(random, dimension);
- }
- if (vectorEncoding == VectorEncoding.BYTE) {
- for (float[] vector : vectors) {
- if (vector != null) {
- for (int i = 0; i < vector.length; i++) {
- vector[i] = (byte) (127 * vector[i]);
- }
- }
- }
- }
- return vectors;
+ static byte[][] createRandomByteVectors(int size, int dimension, Random random) {
+ byte[][] vectors = new byte[size][]; // docs skipped below keep a null vector
+ for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) { // step of 1-3
+ vectors[offset] = randomVector8(random, dimension);
}
+ return vectors;
}
/**
@@ -974,7 +970,7 @@ public class TestHnswGraph extends LuceneTestCase {
return bits;
}
- private static float[] randomVector(Random random, int dim) {
+ static float[] randomVector(Random random, int dim) {
float[] vec = new float[dim];
for (int i = 0; i < dim; i++) {
vec[i] = random.nextFloat();
@@ -986,11 +982,12 @@ public class TestHnswGraph extends LuceneTestCase {
return vec;
}
- private static float[] randomVector8(Random random, int dim) {
+ static byte[] randomVector8(Random random, int dim) { // randomVector, quantized to bytes
float[] fvec = randomVector(random, dim);
+ byte[] bvec = new byte[dim];
for (int i = 0; i < dim; i++) {
- fvec[i] *= 127;
+ bvec[i] = (byte) (fvec[i] * 127); // scale by 127, then narrow (truncating) to byte
}
- return fvec;
+ return bvec;
}
}
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
index f11932fb787..9c210e5fe14 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
@@ -45,6 +45,7 @@ import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsReader;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.FieldType;
+import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.CodecReader;
@@ -707,7 +708,17 @@ public class KnnGraphTester {
iwc.setUseCompoundFile(false);
// iwc.setMaxBufferedDocs(10000);
- FieldType fieldType = KnnVectorField.createFieldType(dim, vectorEncoding, similarityFunction);
+ final FieldType fieldType;
+ switch (vectorEncoding) {
+ case BYTE:
+ fieldType = KnnByteVectorField.createFieldType(dim, similarityFunction);
+ break;
+ default:
+ case FLOAT32:
+ fieldType = KnnVectorField.createFieldType(dim, similarityFunction);
+ break;
+ }
+ ;
if (quiet == false) {
iwc.setInfoStream(new PrintStreamInfoStream(System.out));
System.out.println("creating index in " + indexPath);
@@ -722,7 +733,7 @@ public class KnnGraphTester {
switch (vectorEncoding) {
case BYTE:
doc.add(
- new KnnVectorField(
+ new KnnByteVectorField(
KNN_FIELD, ((VectorReaderByte) vectorReader).nextBytes(), fieldType));
break;
default:
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java
new file mode 100644
index 00000000000..8565cacce2c
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util.hnsw;
+
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.BytesRef;
+
+/** Test-only {@link AbstractMockVectorValues} holding byte vectors as {@link BytesRef}s. */
+class MockByteVectorValues extends AbstractMockVectorValues<BytesRef> {
+  /** Buffer that {@link #vectorValue()} sometimes copies the current vector into. */
+  private final byte[] scratch;
+
+  /**
+   * Builds mock values from per-document byte vectors.
+   *
+   * @param byteValues one entry per document; {@code null} entries are kept as {@code null} in
+   *     {@code values} and skipped when packing {@code denseValues}. {@code byteValues[0]} must be
+   *     non-null because its length is used as the vector dimension.
+   */
+  static MockByteVectorValues fromValues(byte[][] byteValues) {
+    int dimension = byteValues[0].length;
+    BytesRef[] values = new BytesRef[byteValues.length];
+    for (int i = 0; i < byteValues.length; i++) {
+      values[i] = byteValues[i] == null ? null : new BytesRef(byteValues[i]);
+    }
+    // Pack the non-null vectors to the front, preserving their order.
+    BytesRef[] denseValues = new BytesRef[values.length];
+    int count = 0;
+    for (int i = 0; i < byteValues.length; i++) {
+      if (values[i] != null) {
+        denseValues[count++] = values[i];
+      }
+    }
+    return new MockByteVectorValues(values, dimension, denseValues, count);
+  }
+
+  MockByteVectorValues(BytesRef[] values, int dimension, BytesRef[] denseValues, int numVectors) {
+    super(values, dimension, denseValues, numVectors);
+    scratch = new byte[dimension];
+  }
+
+  /** Returns a copy backed by new arrays that hold the same {@link BytesRef} instances. */
+  @Override
+  public MockByteVectorValues copy() {
+    return new MockByteVectorValues(
+        ArrayUtil.copyOfSubArray(values, 0, values.length),
+        dimension,
+        ArrayUtil.copyOfSubArray(denseValues, 0, denseValues.length),
+        numVectors);
+  }
+
+  /**
+   * Returns the vector at position {@code pos}: randomly either the stored {@link BytesRef} or a
+   * copy of it backed by the reused {@code scratch} array.
+   */
+  @Override
+  public BytesRef vectorValue() {
+    if (LuceneTestCase.random().nextBoolean()) {
+      return values[pos];
+    } else {
+      // Sometimes use the same scratch array repeatedly, mimicking what the codec will do.
+      // This should help us catch cases of aliasing where the same ByteVectorValues source is
+      // used twice in a single computation.
+      System.arraycopy(values[pos].bytes, values[pos].offset, scratch, 0, dimension);
+      return new BytesRef(scratch);
+    }
+  }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java
index 299d27857da..9eaa6163a6e 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java
@@ -17,52 +17,37 @@
package org.apache.lucene.util.hnsw;
-import org.apache.lucene.index.VectorValues;
import org.apache.lucene.tests.util.LuceneTestCase;
-import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.ArrayUtil;
-class MockVectorValues extends VectorValues implements RandomAccessVectorValues {
+class MockVectorValues extends AbstractMockVectorValues<float[]> { // float32 mock over float[][]
 private final float[] scratch;
- protected final int dimension;
- protected final float[][] denseValues;
- protected final float[][] values;
- private final int numVectors;
- private final BytesRef binaryValue;
-
- private int pos = -1;
-
- MockVectorValues(float[][] values) {
- this.dimension = values[0].length;
- this.values = values;
+ static MockVectorValues fromValues(float[][] values) { // values[doc] == null: doc has no vector
+ int dimension = values[0].length; // NOTE(review): assumes the first doc has a vector (NPE otherwise)
 int maxDoc = values.length;
- denseValues = new float[maxDoc][];
+ float[][] denseValues = new float[maxDoc][]; // ord -> vector; only the first `count` slots get filled
 int count = 0;
 for (int i = 0; i < maxDoc; i++) {
 if (values[i] != null) {
 denseValues[count++] = values[i];
 }
 }
- numVectors = count;
- scratch = new float[dimension];
- // used by tests that build a graph from bytes rather than floats
- binaryValue = new BytesRef(dimension);
- binaryValue.length = dimension;
+ return new MockVectorValues(values, dimension, denseValues, count);
 }
- @Override
- public MockVectorValues copy() {
- return new MockVectorValues(values);
- }
-
- @Override
- public int size() {
- return numVectors;
+ MockVectorValues(float[][] values, int dimension, float[][] denseValues, int numVectors) {
+ super(values, dimension, denseValues, numVectors);
+ this.scratch = new float[dimension];
 }
 @Override
- public int dimension() {
- return dimension;
+ public MockVectorValues copy() { // shallow: new outer arrays, shared float[] rows
+ return new MockVectorValues(
+ ArrayUtil.copyOfSubArray(values, 0, values.length),
+ dimension,
+ ArrayUtil.copyOfSubArray(denseValues, 0, denseValues.length),
+ numVectors);
 }
 @Override
@@ -83,42 +68,4 @@ class MockVectorValues extends VectorValues implements RandomAccessVectorValues
 public float[] vectorValue(int targetOrd) {
 return denseValues[targetOrd];
 }
-
- @Override
- public BytesRef binaryValue(int targetOrd) {
- float[] value = vectorValue(targetOrd);
- for (int i = 0; i < value.length; i++) {
- binaryValue.bytes[i] = (byte) value[i];
- }
- return binaryValue;
- }
-
- private boolean seek(int target) {
- if (target >= 0 && target < values.length && values[target] != null) {
- pos = target;
- return true;
- } else {
- return false;
- }
- }
-
- @Override
- public int docID() {
- return pos;
- }
-
- @Override
- public int nextDoc() {
- return advance(pos + 1);
- }
-
- @Override
- public int advance(int target) {
- while (++pos < values.length) {
- if (seek(pos)) {
- return pos;
- }
- }
- return NO_MORE_DOCS;
- }
 }
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java
new file mode 100644
index 00000000000..0e09ba5ce33
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util.hnsw;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+import com.carrotsearch.randomizedtesting.RandomizedTest;
+import java.io.IOException;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.KnnByteVectorField;
+import org.apache.lucene.index.ByteVectorValues;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.KnnByteVectorQuery;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.BytesRef;
+import org.junit.Before;
+
+/** Tests HNSW KNN graphs built over {@link VectorEncoding#BYTE} vectors. */
+public class TestHnswByteVectorGraph extends HnswGraphTestCase<BytesRef> {
+
+ @Before
+ public void setup() {
+ similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values());
+ }
+
+ @Override
+ VectorEncoding getVectorEncoding() {
+ return VectorEncoding.BYTE;
+ }
+
+ @Override
+ Query knnQuery(String field, BytesRef vector, int k) {
+ return new KnnByteVectorQuery(field, vector, k);
+ }
+
+ @Override
+ BytesRef randomVector(int dim) {
+ return new BytesRef(randomVector8(random(), dim));
+ }
+
+ @Override
+ AbstractMockVectorValues<BytesRef> vectorValues(int size, int dimension) {
+ return MockByteVectorValues.fromValues(createRandomByteVectors(size, dimension, random()));
+ }
+
+ /** Returns true if {@code v} is an integral value within the byte range [-128, 127]. */
+ static boolean fitsInByte(float v) {
+ return v <= 127 && v >= -128 && v % 1 == 0;
+ }
+
+ @Override
+ AbstractMockVectorValues<BytesRef> vectorValues(float[][] values) {
+ // Cast directly only if every component of every non-null vector already fits in a byte,
+ // otherwise scale all components by 127. One component (e.g. a leading 0) is not enough.
+ boolean scaleSimple = true;
+ for (float[] vector : values) {
+ for (int j = 0; scaleSimple && vector != null && j < vector.length; j++) {
+ scaleSimple = fitsInByte(vector[j]);
+ }
+ }
+ byte[][] bValues = new byte[values.length][];
+ for (int i = 0; i < values.length; i++) {
+ if (values[i] != null) { // missing (sparse) vectors stay null
+ bValues[i] = new byte[values[i].length];
+ for (int j = 0; j < values[i].length; j++) {
+ bValues[i][j] = (byte) (scaleSimple ? values[i][j] : values[i][j] * 127);
+ }
+ }
+ }
+ return MockByteVectorValues.fromValues(bValues);
+ }
+
+ @Override
+ AbstractMockVectorValues<BytesRef> vectorValues(LeafReader reader, String fieldName)
+ throws IOException {
+ ByteVectorValues vectorValues = reader.getByteVectorValues(fieldName);
+ byte[][] vectors = new byte[reader.maxDoc()][];
+ while (vectorValues.nextDoc() != NO_MORE_DOCS) {
+ BytesRef value = vectorValues.vectorValue();
+ vectors[vectorValues.docID()] =
+ ArrayUtil.copyOfSubArray(value.bytes, value.offset, value.offset + value.length);
+ }
+ return MockByteVectorValues.fromValues(vectors);
+ }
+
+ @Override
+ Field knnVectorField(String name, BytesRef vector, VectorSimilarityFunction similarityFunction) {
+ return new KnnByteVectorField(name, vector, similarityFunction);
+ }
+
+ @Override
+ RandomAccessVectorValues<BytesRef> circularVectorValues(int nDoc) {
+ return new CircularByteVectorValues(nDoc);
+ }
+
+ @Override
+ BytesRef getTargetVector() {
+ return new BytesRef(new byte[] {1, 0});
+ }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java
new file mode 100644
index 00000000000..42e21c7ab76
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util.hnsw;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+import com.carrotsearch.randomizedtesting.RandomizedTest;
+import java.io.IOException;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.index.VectorValues;
+import org.apache.lucene.search.KnnVectorQuery;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.FixedBitSet;
+import org.junit.Before;
+
+/** Tests HNSW KNN graphs built over {@link VectorEncoding#FLOAT32} vectors. */
+public class TestHnswFloatVectorGraph extends HnswGraphTestCase<float[]> {
+
+ @Before
+ public void setup() {
+ similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values());
+ }
+
+ @Override
+ VectorEncoding getVectorEncoding() {
+ return VectorEncoding.FLOAT32;
+ }
+
+ @Override
+ Query knnQuery(String field, float[] vector, int k) {
+ return new KnnVectorQuery(field, vector, k);
+ }
+
+ @Override
+ float[] randomVector(int dim) {
+ return randomVector(random(), dim);
+ }
+
+ @Override
+ AbstractMockVectorValues<float[]> vectorValues(int size, int dimension) {
+ return MockVectorValues.fromValues(createRandomFloatVectors(size, dimension, random()));
+ }
+
+ @Override
+ AbstractMockVectorValues<float[]> vectorValues(float[][] values) {
+ return MockVectorValues.fromValues(values);
+ }
+
+ @Override
+ AbstractMockVectorValues<float[]> vectorValues(LeafReader reader, String fieldName)
+ throws IOException {
+ VectorValues vectorValues = reader.getVectorValues(fieldName);
+ float[][] vectors = new float[reader.maxDoc()][];
+ while (vectorValues.nextDoc() != NO_MORE_DOCS) {
+ // Copy, since the returned array may be a buffer the reader reuses across documents.
+ float[] value = vectorValues.vectorValue();
+ vectors[vectorValues.docID()] = ArrayUtil.copyOfSubArray(value, 0, value.length);
+ }
+ return MockVectorValues.fromValues(vectors);
+ }
+
+ @Override
+ Field knnVectorField(String name, float[] vector, VectorSimilarityFunction similarityFunction) {
+ return new KnnVectorField(name, vector, similarityFunction);
+ }
+
+ @Override
+ RandomAccessVectorValues<float[]> circularVectorValues(int nDoc) {
+ return new CircularVectorValues(nDoc);
+ }
+
+ @Override
+ float[] getTargetVector() {
+ return new float[] {1f, 0f};
+ }
+
+ public void testSearchWithSkewedAcceptOrds() throws IOException {
+ int nDoc = 1000;
+ similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
+ RandomAccessVectorValues<float[]> vectors = circularVectorValues(nDoc);
+ HnswGraphBuilder<float[]> builder =
+ HnswGraphBuilder.create(
+ vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt());
+ OnHeapHnswGraph hnsw = builder.build(vectors.copy());
+
+ // Skip over half of the documents that are closest to the query vector
+ FixedBitSet acceptOrds = new FixedBitSet(nDoc);
+ for (int i = 500; i < nDoc; i++) {
+ acceptOrds.set(i);
+ }
+ NeighborQueue nn =
+ HnswGraphSearcher.search(
+ getTargetVector(),
+ 10,
+ vectors.copy(),
+ getVectorEncoding(),
+ similarityFunction,
+ hnsw,
+ acceptOrds,
+ Integer.MAX_VALUE);
+
+ int[] nodes = nn.nodes();
+ assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
+ int sum = 0;
+ for (int node : nodes) {
+ assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
+ sum += node;
+ }
+ // We still expect to get reasonable recall. The lowest non-skipped docIds
+ // are closest to the query vector: sum(500,509) = 5045
+ assertTrue("sum(result docs)=" + sum, sum < 5100);
+ }
+}
diff --git a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java
index 7e71720c6aa..9720b0f286f 100644
--- a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java
+++ b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java
@@ -20,6 +20,7 @@ import java.io.IOException;
import java.util.Collections;
import java.util.Iterator;
import org.apache.lucene.index.BinaryDocValues;
+import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
@@ -165,6 +166,11 @@ public class TermVectorLeafReader extends LeafReader {
return null;
}
+ @Override
+ public ByteVectorValues getByteVectorValues(String fieldName) {
+ return null; // this reader never exposes byte vector values
+ }
+
@Override
public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
diff --git a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java
index 36a330829b2..828d2f556a7 100644
--- a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java
+++ b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java
@@ -1396,6 +1396,11 @@ public class MemoryIndex {
return null;
}
+ @Override
+ public ByteVectorValues getByteVectorValues(String fieldName) {
+ return null; // this reader never exposes byte vector values
+ }
+
@Override
public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java
index dd8f2cdbaf5..ec930fd9052 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java
@@ -22,6 +22,7 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
+import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.MergeState;
@@ -113,7 +114,9 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
@Override
public VectorValues getVectorValues(String field) throws IOException {
FieldInfo fi = fis.fieldInfo(field);
- assert fi != null && fi.getVectorDimension() > 0;
+ assert fi != null
+ && fi.getVectorDimension() > 0
+ && fi.getVectorEncoding() == VectorEncoding.FLOAT32;
VectorValues values = delegate.getVectorValues(field);
assert values != null;
assert values.docID() == -1;
@@ -122,6 +125,20 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
return values;
}
+ @Override
+ public ByteVectorValues getByteVectorValues(String field) throws IOException {
+ FieldInfo fi = fis.fieldInfo(field);
+ assert fi != null // the field must exist, have vectors, and be BYTE-encoded
+ && fi.getVectorDimension() > 0
+ && fi.getVectorEncoding() == VectorEncoding.BYTE;
+ ByteVectorValues values = delegate.getByteVectorValues(field);
+ assert values != null;
+ assert values.docID() == -1; // the returned iterator must not be positioned yet
+ assert values.size() >= 0;
+ assert values.dimension() > 0;
+ return values;
+ }
+
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException {
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java
index 873d1eab409..5d05f9a87a0 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java
@@ -28,10 +28,12 @@ import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
+import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.document.StringField;
+import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CheckIndex;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader;
@@ -80,7 +82,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
protected void addRandomFields(Document doc) {
switch (vectorEncoding) {
case BYTE:
- doc.add(new KnnVectorField("v2", randomVector8(30), similarityFunction));
+ doc.add(new KnnByteVectorField("v2", new BytesRef(randomVector8(30)), similarityFunction));
break;
default:
case FLOAT32:
@@ -634,9 +636,11 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
switch (fieldVectorEncodings[field]) {
case BYTE:
{
- BytesRef b = randomVector8(fieldDims[field]);
- doc.add(new KnnVectorField(fieldName, b, fieldSimilarityFunctions[field]));
- fieldTotals[field] += b.bytes[b.offset];
+ byte[] b = randomVector8(fieldDims[field]);
+ doc.add(
+ new KnnByteVectorField(
+ fieldName, new BytesRef(b), fieldSimilarityFunctions[field]));
+ fieldTotals[field] += b[0];
break;
}
default:
@@ -658,14 +662,29 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
int docCount = 0;
double checksum = 0;
String fieldName = "int" + field;
- for (LeafReaderContext ctx : r.leaves()) {
- VectorValues vectors = ctx.reader().getVectorValues(fieldName);
- if (vectors != null) {
- docCount += vectors.size();
- while (vectors.nextDoc() != NO_MORE_DOCS) {
- checksum += vectors.vectorValue()[0];
+ switch (fieldVectorEncodings[field]) {
+ case BYTE: // a BytesRef's first byte is bytes[offset], not necessarily bytes[0]
+ for (LeafReaderContext ctx : r.leaves()) {
+ ByteVectorValues byteVectorValues = ctx.reader().getByteVectorValues(fieldName);
+ if (byteVectorValues != null) {
+ docCount += byteVectorValues.size();
+ while (byteVectorValues.nextDoc() != NO_MORE_DOCS) {
+ checksum += byteVectorValues.vectorValue().bytes[byteVectorValues.vectorValue().offset];
+ }
+ }
+ }
+ break;
+ default:
+ case FLOAT32:
+ for (LeafReaderContext ctx : r.leaves()) {
+ VectorValues vectorValues = ctx.reader().getVectorValues(fieldName);
+ if (vectorValues != null) {
+ docCount += vectorValues.size();
+ while (vectorValues.nextDoc() != NO_MORE_DOCS) {
+ checksum += vectorValues.vectorValue()[0];
+ }
+ }
 }
- }
 }
assertEquals(fieldDocCounts[field], docCount);
// Account for quantization done when indexing fields w/BYTE encoding
@@ -765,15 +784,15 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
LeafReader leaf = getOnlyLeafReader(reader);
StoredFields storedFields = leaf.storedFields();
- VectorValues vectorValues = leaf.getVectorValues(fieldName);
+ ByteVectorValues vectorValues = leaf.getByteVectorValues(fieldName);
 assertEquals(2, vectorValues.dimension());
 assertEquals(3, vectorValues.size());
 assertEquals("1", storedFields.document(vectorValues.nextDoc()).get("id"));
- assertEquals(-1f, vectorValues.vectorValue()[0], 0);
+ assertEquals(-1, vectorValues.vectorValue().bytes[vectorValues.vectorValue().offset]); // exact; honors offset
 assertEquals("2", storedFields.document(vectorValues.nextDoc()).get("id"));
- assertEquals(1, vectorValues.vectorValue()[0], 0);
+ assertEquals(1, vectorValues.vectorValue().bytes[vectorValues.vectorValue().offset]);
 assertEquals("4", storedFields.document(vectorValues.nextDoc()).get("id"));
- assertEquals(0, vectorValues.vectorValue()[0], 0);
+ assertEquals(0, vectorValues.vectorValue().bytes[vectorValues.vectorValue().offset]);
 assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
}
}
@@ -925,7 +944,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
for (int i = 0; i < numDoc; i++) {
if (random().nextInt(7) != 3) {
// usually index a vector value for a doc
- values[i] = randomVector8(dimension);
+ values[i] = new BytesRef(randomVector8(dimension));
++numValues;
}
if (random().nextBoolean() && values[i] != null) {
@@ -953,7 +972,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
try (IndexReader reader = DirectoryReader.open(iw)) {
int valueCount = 0, totalSize = 0;
for (LeafReaderContext ctx : reader.leaves()) {
- VectorValues vectorValues = ctx.reader().getVectorValues(fieldName);
+ ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName);
if (vectorValues == null) {
continue;
}
@@ -961,7 +980,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
StoredFields storedFields = ctx.reader().storedFields();
int docId;
while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) {
- BytesRef v = vectorValues.binaryValue();
+ BytesRef v = vectorValues.vectorValue();
assertEquals(dimension, v.length);
String idString = storedFields.document(docId).getField("id").stringValue();
int id = Integer.parseInt(idString);
@@ -1151,7 +1170,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
throws IOException {
Document doc = new Document();
if (vector != null) {
- doc.add(new KnnVectorField(field, vector, similarityFunction));
+ doc.add(new KnnByteVectorField(field, vector, similarityFunction));
}
doc.add(new NumericDocValuesField("sortkey", sortKey));
String idString = Integer.toString(id);
@@ -1193,13 +1212,13 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
return v;
}
- private BytesRef randomVector8(int dim) {
+ private byte[] randomVector8(int dim) { // components of randomVector(dim), times 127, cast to byte
 float[] v = randomVector(dim);
 byte[] b = new byte[dim];
 for (int i = 0; i < dim; i++) {
 b[i] = (byte) (v[i] * 127);
 }
- return new BytesRef(b);
+ return b;
 }
public void testCheckIndexIncludesVectors() throws Exception {
@@ -1308,9 +1327,9 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
switch (vectorEncoding) {
case BYTE:
{
- BytesRef b = randomVector8(dim);
- fieldValuesCheckSum += b.bytes[b.offset];
- doc.add(new KnnVectorField("knn_vector", b, similarityFunction));
+ byte[] b = randomVector8(dim);
+ fieldValuesCheckSum += b[0];
+ doc.add(new KnnByteVectorField("knn_vector", new BytesRef(b), similarityFunction));
break;
}
case FLOAT32:
@@ -1335,17 +1354,36 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
double checksum = 0;
int docCount = 0;
long sumDocIds = 0;
- for (LeafReaderContext ctx : r.leaves()) {
- VectorValues vectors = ctx.reader().getVectorValues("knn_vector");
- if (vectors != null) {
- StoredFields storedFields = ctx.reader().storedFields();
- docCount += vectors.size();
- while (vectors.nextDoc() != NO_MORE_DOCS) {
- checksum += vectors.vectorValue()[0];
- Document doc = storedFields.document(vectors.docID(), Set.of("id"));
- sumDocIds += Integer.parseInt(doc.get("id"));
+ switch (vectorEncoding) {
+ case BYTE: // a BytesRef's first byte is bytes[offset], not necessarily bytes[0]
+ for (LeafReaderContext ctx : r.leaves()) {
+ ByteVectorValues byteVectorValues = ctx.reader().getByteVectorValues("knn_vector");
+ if (byteVectorValues != null) {
+ docCount += byteVectorValues.size();
+ StoredFields storedFields = ctx.reader().storedFields();
+ while (byteVectorValues.nextDoc() != NO_MORE_DOCS) {
+ checksum += byteVectorValues.vectorValue().bytes[byteVectorValues.vectorValue().offset];
+ Document doc = storedFields.document(byteVectorValues.docID(), Set.of("id"));
+ sumDocIds += Integer.parseInt(doc.get("id"));
+ }
+ }
 }
- }
+ break;
+ default:
+ case FLOAT32:
+ for (LeafReaderContext ctx : r.leaves()) {
+ VectorValues vectorValues = ctx.reader().getVectorValues("knn_vector");
+ if (vectorValues != null) {
+ docCount += vectorValues.size();
+ StoredFields storedFields = ctx.reader().storedFields();
+ while (vectorValues.nextDoc() != NO_MORE_DOCS) {
+ checksum += vectorValues.vectorValue()[0];
+ Document doc = storedFields.document(vectorValues.docID(), Set.of("id"));
+ sumDocIds += Integer.parseInt(doc.get("id"));
+ }
+ }
+ }
+ break;
 }
assertEquals(
fieldValuesCheckSum,
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java
index e54721fca00..0863e2681b9 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java
@@ -25,6 +25,7 @@ import org.apache.lucene.codecs.NormsProducer;
import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.index.BinaryDocValues;
+import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.FieldInfo;
@@ -235,6 +236,11 @@ class MergeReaderWrapper extends LeafReader {
return in.getVectorValues(fieldName);
}
+ @Override
+ public ByteVectorValues getByteVectorValues(String fieldName) throws IOException {
+ return in.getByteVectorValues(fieldName); // delegate to the wrapped reader
+ }
+
@Override
public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java
index 6e01564cdcd..bc4cdd0be50 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java
@@ -24,6 +24,7 @@ import java.io.IOException;
import java.util.List;
import java.util.Random;
import org.apache.lucene.index.BinaryDocValues;
+import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.Fields;
import org.apache.lucene.index.IndexReader;
@@ -230,6 +231,11 @@ public class QueryUtils {
return null;
}
+ @Override
+ public ByteVectorValues getByteVectorValues(String field) throws IOException {
+ return null; // this reader never exposes byte vector values
+ }
+
@Override
public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {