You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by so...@apache.org on 2019/10/10 22:46:19 UTC
[lucene-solr] 01/01: LUCENE-9004: approximate nearest vector search (WIP)
This is an automated email from the ASF dual-hosted git repository.
sokolov pushed a commit to branch LUCENE-9004
in repository https://gitbox.apache.org/repos/asf/lucene-solr.git
commit 2aba8b89796150760fa8939452b5d9271d7b5b7b
Author: Michael Sokolov <so...@amazon.com>
AuthorDate: Mon Sep 9 18:44:22 2019 -0400
LUCENE-9004: approximate nearest vector search (WIP)
---
.gitignore | 3 +-
.../apache/lucene/codecs/DocValuesConsumer.java | 311 ++++++++++++++++-
.../codecs/perfield/PerFieldDocValuesFormat.java | 10 +-
.../org/apache/lucene/document/KnnGraphField.java | 76 +++++
.../lucene/document/ReferenceDocValuesField.java | 91 +++++
.../lucene/document/VectorDocValuesField.java | 74 ++++
.../java/org/apache/lucene/index/CheckIndex.java | 2 +-
.../apache/lucene/index/DefaultIndexingChain.java | 43 ++-
.../java/org/apache/lucene/index/DocValues.java | 2 +-
.../org/apache/lucene/index/DocsWithFieldSet.java | 16 +
.../org/apache/lucene/index/FilterLeafReader.java | 2 +-
.../org/apache/lucene/index/KnnGraphWriter.java | 85 +++++
.../lucene/index/ReferenceDocValuesWriter.java | 331 ++++++++++++++++++
.../org/apache/lucene/index/SNDVWriterBase.java | 23 ++
.../lucene/index/SortedNumericDocValuesWriter.java | 7 +-
.../org/apache/lucene/index/VectorDocValues.java | 136 ++++++++
.../apache/lucene/index/VectorDocValuesWriter.java | 276 +++++++++++++++
.../java/org/apache/lucene/search/GraphSearch.java | 337 +++++++++++++++++++
.../org/apache/lucene/index/TestDocValues.java | 37 ++
.../apache/lucene/index/TestDocValuesIndexing.java | 7 +
.../test/org/apache/lucene/index/TestKnnGraph.java | 255 ++++++++++++++
.../lucene/index/TestReferenceDocValues.java | 360 ++++++++++++++++++++
.../apache/lucene/index/TestVectorDocValues.java | 344 +++++++++++++++++++
.../org/apache/lucene/search/KnnGraphTester.java | 373 +++++++++++++++++++++
.../apache/lucene/index/memory/MemoryIndex.java | 2 +-
.../test/org/apache/solr/search/TestDocSet.java | 2 +-
26 files changed, 3176 insertions(+), 29 deletions(-)
diff --git a/.gitignore b/.gitignore
index d0f0ade..ce7efb3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,7 +7,7 @@ build
dist
lib
test-lib
-/*~
+*~
/velocity.log
/build.properties
/.idea
@@ -29,4 +29,3 @@ pom.xml
__pycache__
/dev-tools/scripts/scripts.iml
.DS_Store
-
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/DocValuesConsumer.java b/lucene/core/src/java/org/apache/lucene/codecs/DocValuesConsumer.java
index 8526be6..68f1be2 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/DocValuesConsumer.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/DocValuesConsumer.java
@@ -19,9 +19,11 @@ package org.apache.lucene.codecs;
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
+import org.apache.lucene.document.ReferenceDocValuesField;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocIDMerger;
import org.apache.lucene.index.DocValues;
@@ -32,13 +34,19 @@ import org.apache.lucene.index.FilteredTermsEnum;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.OrdinalMap;
+import org.apache.lucene.index.ReferenceDocValuesWriter;
import org.apache.lucene.index.SegmentWriteState; // javadocs
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.index.VectorDocValues;
+import org.apache.lucene.index.VectorDocValuesWriter;
+import org.apache.lucene.search.GraphSearch;
+import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.Counter;
import org.apache.lucene.util.LongBitSet;
import org.apache.lucene.util.LongValues;
import org.apache.lucene.util.packed.PackedInts;
@@ -140,7 +148,12 @@ public abstract class DocValuesConsumer implements Closeable {
} else if (type == DocValuesType.SORTED_SET) {
mergeSortedSetField(mergeFieldInfo, mergeState);
} else if (type == DocValuesType.SORTED_NUMERIC) {
- mergeSortedNumericField(mergeFieldInfo, mergeState);
+ String refType = mergeFieldInfo.getAttribute(ReferenceDocValuesField.REFTYPE_ATTR);
+ if ("knn-graph".equals(refType)) {
+ mergeReferenceField(mergeFieldInfo, mergeState);
+ } else {
+ mergeSortedNumericField(mergeFieldInfo, mergeState);
+ }
} else {
throw new AssertionError("type=" + type);
}
@@ -450,6 +463,302 @@ public abstract class DocValuesConsumer implements Closeable {
});
}
+ /**
+ * Merges a KNN-graph reference field by recomputing the references rather than remapping them.
+ * <p>
+ * {@code mergeFieldInfo} is expected to be named {@code <vectorField>$nbr} (the name
+ * DefaultIndexingChain gives the companion reference field of a {@code KnnGraphField}). For each
+ * merging segment whose field infos contain the BINARY vector field, the vectors are exposed
+ * twice: through a {@link VectorDocValuesSub} for the sequential pass below, and through a
+ * {@link VectorDocValuesSupplier} that lets {@code MultiVectorDV} re-open the values for the
+ * random access needed by graph search. Documents are then visited segment by segment; each
+ * document's vector is searched with {@link GraphSearch} against the merged vectors and the
+ * references buffered in {@code refWriter} (presumably a live view of what has been added so
+ * far - confirm), and the non-sentinel hits are buffered as that document's references. The
+ * buffered references are finally written through {@link #addSortedNumericField}.
+ * <p>
+ * NOTE(review): when no merging segment has the vector field, {@code subs} is empty and
+ * {@code dimension} stays -1, so the setup below ({@code MultiVectorDV},
+ * {@code new float[dimension]}) would throw - confirm this case cannot be reached.
+ */
+ public void mergeReferenceField(FieldInfo mergeFieldInfo, final MergeState mergeState) throws IOException {
+
+ // the reference field's name is the vector field's name followed by the 4-character "$nbr" suffix
+ assert mergeFieldInfo.name.substring(mergeFieldInfo.name.length() - 4).equals("$nbr");
+ String vectorFieldName = mergeFieldInfo.name.substring(0, mergeFieldInfo.name.length() - 4);
+ // We must compute the entire merged field in memory since each document's values depend on its neighbors
+ //mergeState.infoStream.message("ReferenceDocValues", "merging " + mergeState.segmentInfo);
+ List<VectorDocValuesSub> subs = new ArrayList<>();
+ List<VectorDocValuesSupplier> suppliers = new ArrayList<>();
+ int dimension = -1;
+ for (int i = 0 ; i < mergeState.docValuesProducers.length; i++) {
+ DocValuesProducer docValuesProducer = mergeState.docValuesProducers[i];
+ if (docValuesProducer != null) {
+ FieldInfo vectorFieldInfo = mergeState.fieldInfos[i].fieldInfo(vectorFieldName);
+ if (vectorFieldInfo != null && vectorFieldInfo.getDocValuesType() == DocValuesType.BINARY) {
+ int segmentDimension = VectorDocValuesWriter.getDimensionFromAttribute(vectorFieldInfo);
+ // every merging segment must agree on the vector dimension
+ if (dimension == -1) {
+ dimension = segmentDimension;
+ } else if (dimension != segmentDimension) {
+ throw new IllegalStateException("Varying dimensions for vector-valued field " + mergeFieldInfo.name
+ + ": " + dimension + "!=" + segmentDimension);
+ }
+ // "values" is consumed sequentially by the merge loop below; the supplier opens a fresh
+ // iterator each time it is called so graph search can restart iteration from the beginning
+ VectorDocValues values = VectorDocValues.get(docValuesProducer.getBinary(vectorFieldInfo), dimension);
+ suppliers.add(() -> VectorDocValues.get(docValuesProducer.getBinary(vectorFieldInfo), segmentDimension));
+ subs.add(new VectorDocValuesSub(i, mergeState.docMaps[i], values));
+ }
+ }
+ }
+ // Create a new SortedNumericDocValues by iterating over the vectors, searching for
+ // its nearest neighbor vectors in the newly merged segments' vectors, mapping the resulting
+ // docids using docMaps in the mergeState.
+ MultiVectorDV multiVectors = new MultiVectorDV(suppliers, subs, mergeState.maxDocs);
+ ReferenceDocValuesWriter refWriter = new ReferenceDocValuesWriter(mergeFieldInfo, Counter.newCounter(false));
+ SortedNumericDocValues refs = refWriter.getBufferedValues();
+ float[] vector = new float[dimension];
+ GraphSearch graphSearch = GraphSearch.fromDimension(dimension);
+ // Consume the first document that has a vector (from the first sub that has any) without
+ // searching: there is nothing to link it to yet. Presumably it gains references when later
+ // documents link to it (KNN graph references are described as symmetric) - confirm.
+ int i;
+ for (i = 0; i < subs.size(); i++) {
+ // advance past the first document; there are no neighbors for it
+ if (subs.get(i).nextDoc() != NO_MORE_DOCS) {
+ break;
+ }
+ }
+ // continue from the sub holding that first document, visiting the rest in segment order
+ for (; i < subs.size(); i++) {
+ VectorDocValuesSub sub = subs.get(i);
+ MergeState.DocMap docMap = mergeState.docMaps[sub.segmentIndex];
+ // nocommit: test sorted index and test index with deletions
+ int docid;
+ while ((docid = sub.nextDoc()) != NO_MORE_DOCS) {
+ // NOTE(review): MergeState.DocMap typically maps deleted docs to -1, and deleted docs are
+ // not skipped here (see the nocommit above) - confirm before relying on deletions
+ int mappedDocId = docMap.get(docid);
+ assert sub.values.docID() == docid;
+ assert docid == multiVectors.unmap(mappedDocId) : "unmap mismatch " + docid + " != " + multiVectors.unmap(mappedDocId);
+ sub.values.vector(vector);
+ //System.out.println("merge doc " + mappedDocId + " mapped from [" + i + "," + docid + "] in thread " + Thread.currentThread().getName());
+ for (ScoreDoc ref : graphSearch.search(() -> multiVectors, () -> refs, vector, mappedDocId)) {
+ if (ref.doc >= 0) {
+ // ignore sentinels
+ //System.out.println(" ref " + ref.doc);
+ refWriter.addValue(mappedDocId, ref.doc);
+ }
+ }
+ }
+ }
+
+ // hand the fully-built, in-memory references to the codec as the merged field's values
+ addSortedNumericField(mergeFieldInfo,
+ new EmptyDocValuesProducer() {
+ @Override
+ public SortedNumericDocValues getSortedNumeric(FieldInfo fieldInfo) {
+ if (fieldInfo != mergeFieldInfo) {
+ throw new IllegalArgumentException("wrong FieldInfo");
+ }
+ //mergeState.infoStream.message("ReferenceDocValues", "new iterator " + mergeState.segmentInfo);
+ return refWriter.getIterableValues();
+ }
+ });
+
+ //mergeState.infoStream.message("ReferenceDocValues", " mergeReferenceField done: " + mergeState.segmentInfo);
+ }
+
+  /**
+   * Merge-side cursor over one segment's vector doc values, remembering the index of the segment
+   * (within the merge) that the values came from.
+   */
+  private static class VectorDocValuesSub extends DocIDMerger.Sub {
+
+    final int segmentIndex;
+    final VectorDocValues values;
+
+    public VectorDocValuesSub(int segmentIndex, MergeState.DocMap docMap, VectorDocValues values) {
+      super(docMap);
+      this.segmentIndex = segmentIndex;
+      this.values = values;
+      // the wrapped iterator must not have been advanced yet
+      assert values.docID() == -1;
+    }
+
+    /** Steps the wrapped vector iterator; returned ids are segment-local. */
+    @Override
+    public int nextDoc() throws IOException {
+      return this.values.nextDoc();
+    }
+  }
+
+ // provides a view over multiple VectorDocValues by concatenating their docid spaces
+ // Global ("merged") ids lay out every merging segment's full maxDoc range back to back, in
+ // segment order; segments without the vector field still occupy their range (see the
+ // constructor). A global id is mapped back to (sub index, segment-local id) with
+ // findSegment/docBase.
+ // NOTE(review): these global ids presumably coincide with MergeState.DocMap-mapped ids only
+ // when there are no deletions and no index sorting - confirm (see mergeReferenceField).
+ private static class MultiVectorDV extends VectorDocValues {
+ private final VectorDocValues[] subValues;
+ // docBase[j]: first global id of the segment backing sub j
+ private final int[] docBase;
+ // segmentMaxDocs[k]: last global id before sub k+1's docBase; findSegment assigns ids up to and
+ // including this value to sub k (searched with Arrays.binarySearch, so it must be ascending)
+ private final int[] segmentMaxDocs;
+ private final int cost;
+
+ // index of the sub selected by the most recent advance/advanceExact; vector() reads from it
+ private int whichSub;
+
+ MultiVectorDV(List<VectorDocValuesSupplier> suppliers, List<VectorDocValuesSub> subs, int[] maxDocs) throws IOException {
+ this.subValues = new VectorDocValues[suppliers.size()];
+ // TODO: this complicated logic needs its own test
+ // maxDocs actually says *how many* docs there are, not what the number of the max doc is
+ // maxDoc: global id of the last doc of segment i; lastMaxDoc: same for segment i-1 (-1 before the first)
+ int maxDoc = -1;
+ int lastMaxDoc = -1;
+ // NOTE(review): with no subs this is new int[-1] and throws NegativeArraySizeException
+ segmentMaxDocs = new int[subs.size() - 1];
+ docBase = new int[subs.size()];
+ // i walks all merging segments, j walks only the subs (segments that have the vector field)
+ for (int i = 0, j = 0; j < subs.size(); i++) {
+ lastMaxDoc = maxDoc;
+ maxDoc += maxDocs[i];
+ if (i == subs.get(j).segmentIndex) {
+ // we may skip some segments if they have no docs with values for this field
+ if (j > 0) {
+ segmentMaxDocs[j - 1] = lastMaxDoc;
+ }
+ docBase[j] = lastMaxDoc + 1;
+ ++j;
+ }
+ }
+
+ // wrap each segment so it can be revisited out of order during graph search
+ int i = 0;
+ int totalCost = 0;
+ for (VectorDocValuesSupplier supplier : suppliers) {
+ ResettingVectorDV sub = new ResettingVectorDV(supplier);
+ totalCost += sub.cost();
+ this.subValues[i++] = sub;
+ }
+ cost = totalCost;
+ whichSub = 0;
+ }
+
+ // returns the index of the sub that owns global id docid
+ private int findSegment(int docid) {
+ int segment = Arrays.binarySearch(segmentMaxDocs, docid);
+ if (segment < 0) {
+ // not an exact boundary: the insertion point is the owning sub
+ return -1 - segment;
+ } else {
+ // docid is the last id attributed to this sub
+ return segment;
+ }
+ }
+
+ // sequential iteration is not supported; only advance/advanceExact are used
+ @Override
+ public int docID() {
+ throw new UnsupportedOperationException();
+ }
+
+ // NOTE(review): throws ArrayIndexOutOfBoundsException when there are no subs
+ @Override
+ public int dimension() {
+ return subValues[0].dimension();
+ }
+
+ @Override
+ public long cost() {
+ return cost;
+ }
+
+ @Override
+ public int nextDoc() throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
+ // advances to the first global id >= target that has a vector, moving to the next sub when the
+ // selected one is exhausted; returns a global id
+ @Override
+ public int advance(int target) throws IOException {
+ int rebased = unmapSettingWhich(target);
+ if (rebased < 0) {
+ rebased = 0;
+ }
+ int segmentDocId = subValues[whichSub].advance(rebased);
+ if (segmentDocId == NO_MORE_DOCS) {
+ if (++whichSub < subValues.length) {
+ // Get the first document in the next segment; Note that all segments have values.
+ segmentDocId = subValues[whichSub].advance(0);
+ } else {
+ return NO_MORE_DOCS;
+ }
+ }
+ return docBase[whichSub] + segmentDocId;
+ }
+
+ // positions the owning sub exactly on target (a global id); false if it has no vector there
+ @Override
+ public boolean advanceExact(int target) throws IOException {
+ int rebased = unmapSettingWhich(target);
+ if (rebased < 0) {
+ return false;
+ }
+ return subValues[whichSub].advanceExact(rebased);
+ }
+
+ int unmap(int docid) {
+ // map from global (merged) to segment-local (unmerged)
+ // same as unmapSettingWhich but without updating whichSub - used for assertion
+ return docid - docBase[findSegment(docid)];
+ }
+
+ // maps a global id to a segment-local id and selects the owning sub as a side effect
+ private int unmapSettingWhich(int target) {
+ whichSub = findSegment(target);
+ return target - docBase[whichSub];
+ }
+
+ // reads the vector at the position last reached by advance/advanceExact
+ @Override
+ public void vector(float[] vector) throws IOException {
+ subValues[whichSub].vector(vector);
+ }
+ }
+
+ // provides pseudo-random access to the values as float[] by recreating an underlying
+ // iterator whenever the iteration goes backwards
+ private static class ResettingVectorDV extends VectorDocValues {
+
+ private final VectorDocValuesSupplier supplier;
+ private VectorDocValues delegate;
+ private int docId = -1;
+
+ ResettingVectorDV(VectorDocValuesSupplier supplier) throws IOException {
+ this.supplier = supplier;
+ delegate = supplier.get();
+ }
+
+ @Override
+ public int docID() {
+ if (docId < 0) {
+ return -docId;
+ } else {
+ return docId;
+ }
+ }
+
+ @Override
+ public int dimension() {
+ return delegate.dimension();
+ }
+
+ @Override
+ public long cost() {
+ return delegate.cost();
+ }
+
+ @Override
+ public int nextDoc() throws IOException {
+ docId = delegate.nextDoc();
+ return docId;
+ }
+
+ @Override
+ public int advance(int target) throws IOException {
+ if (target == docId) {
+ return target;
+ }
+ maybeReset(target);
+ docId = delegate.advance(target);
+ return docId;
+ }
+
+ @Override
+ public boolean advanceExact(int target) throws IOException {
+ if (target == docId) {
+ return true;
+ }
+ maybeReset(target);
+ boolean advanced = delegate.advanceExact(target);
+ if (advanced) {
+ docId = delegate.docID();
+ } else {
+ docId = -delegate.docID();
+ }
+ return advanced;
+ }
+
+ @Override
+ public void vector(float[] vector) throws IOException {
+ delegate.vector(vector);
+ }
+
+ private void maybeReset(int target) throws IOException {
+ if (target < delegate.docID()) {
+ delegate = supplier.get();
+ }
+ }
+ }
+
+  /**
+   * Creates a new {@link VectorDocValues} iterator each time it is called, so iteration over a
+   * segment's vectors can be restarted. A dedicated interface is used instead of
+   * {@code java.util.function.Supplier} because creating the iterator may throw {@link IOException}.
+   */
+  @FunctionalInterface
+  private interface VectorDocValuesSupplier {
+    VectorDocValues get() throws IOException;
+  }
+
/** Tracks state of one sorted sub-reader that we are merging */
private static class SortedDocValuesSub extends DocIDMerger.Sub {
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldDocValuesFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldDocValuesFormat.java
index f2e8940..6b19535 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldDocValuesFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldDocValuesFormat.java
@@ -136,13 +136,9 @@ public abstract class PerFieldDocValuesFormat extends DocValuesFormat {
// Group each consumer by the fields it handles
for (FieldInfo fi : mergeState.mergeFieldInfos) {
// merge should ignore current format for the fields being merged
- DocValuesConsumer consumer = getInstance(fi, true);
- Collection<String> fieldsForConsumer = consumersToField.get(consumer);
- if (fieldsForConsumer == null) {
- fieldsForConsumer = new ArrayList<>();
- consumersToField.put(consumer, fieldsForConsumer);
- }
- fieldsForConsumer.add(fi.name);
+ consumersToField
+ .computeIfAbsent(getInstance(fi, true), k -> new ArrayList<>())
+ .add(fi.name);
}
// Delegate the merge to the appropriate consumer
diff --git a/lucene/core/src/java/org/apache/lucene/document/KnnGraphField.java b/lucene/core/src/java/org/apache/lucene/document/KnnGraphField.java
new file mode 100644
index 0000000..7a720c6
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/document/KnnGraphField.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.document;
+
+import java.nio.ByteBuffer;
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.apache.lucene.index.DocValuesType;
+import org.apache.lucene.index.VectorDocValues;
+import org.apache.lucene.util.BytesRef;
+
+/**
+ * Field that models a graph that supports (approximate) KNN (K-nearest-neighbor) search. Each
+ * document (that is in the graph) is represented by a vector {@code float[]} value with the
+ * specified length. <p> Here's an example usage: </p>
+ *
+ * <pre class="prettyprint">
+ * document.add(new KnnGraphField(name, new float[]{1f, 2f, 3f}));
+ * </pre>
+ *
+ * <p>The field is indexed as {@link DocValuesType#BINARY} doc values whose field type carries the
+ * {@link VectorDocValues#DIMENSION_ATTR} attribute and {@link #SUBTYPE_ATTR} = {@link #KNN_GRAPH};
+ * the indexing chain uses that subtype to build the neighbor graph alongside the vectors.
+ */
+public class KnnGraphField extends Field {
+
+  /** Name of the field-type attribute that identifies special binary doc-values subtypes. */
+  public static final String SUBTYPE_ATTR = "subtype";
+  /** Value of {@link #SUBTYPE_ATTR} marking a field as a KNN graph vector field. */
+  public static final String KNN_GRAPH = "knn-graph";
+
+  /** Frozen field types, one per vector dimension, created on first use. */
+  private static final ConcurrentHashMap<Integer, FieldType> TYPES = new ConcurrentHashMap<>();
+
+  private static FieldType createType(int dimension) {
+    if (dimension <= 0) {
+      throw new IllegalArgumentException("dimension must be positive, not " + dimension);
+    }
+    FieldType type = new FieldType();
+    type.setDocValuesType(DocValuesType.BINARY);
+    type.putAttribute(VectorDocValues.DIMENSION_ATTR, Integer.toString(dimension));
+    type.putAttribute(SUBTYPE_ATTR, KNN_GRAPH);
+    type.freeze();
+    return type;
+  }
+
+  /**
+   * Dimensioned types; one shared, frozen type per vector length.
+   * @param dimension the vector length
+   * @throws IllegalArgumentException if {@code dimension} is not positive
+   */
+  public static FieldType type(int dimension) {
+    return TYPES.computeIfAbsent(dimension, KnnGraphField::createType);
+  }
+
+  /**
+   * Create a new KnnGraph field wrapping an array of floats. The array is stored by reference,
+   * not copied, so it should not be modified after the field has been created.
+   * @param name field name
+   * @param value vector of float values as an array; its length is the vector dimension
+   * @throws IllegalArgumentException if the field name is null or {@code value} is empty
+   */
+  public KnnGraphField(String name, float[] value) {
+    super(name, type(value.length));
+    fieldsData = value;
+  }
+
+  /** Returns this field's vector: the same array that was passed to the constructor. */
+  public float[] vectorValue() {
+    return (float[]) fieldsData;
+  }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/document/ReferenceDocValuesField.java b/lucene/core/src/java/org/apache/lucene/document/ReferenceDocValuesField.java
new file mode 100644
index 0000000..e2d288c
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/document/ReferenceDocValuesField.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.document;
+
+
+import org.apache.lucene.index.DocValuesType;
+
+/**
+ * <p>
+ * Field that stores per-document <code>int</code> references to documents.
+ * Here's an example usage:
+ *
+ * <pre class="prettyprint">
+ * document.add(new ReferenceDocValuesField(name, 5));
+ * document.add(new ReferenceDocValuesField(name, 14));
+ * </pre>
+ *
+ * <p>
+ * References are indexed as {@link DocValuesType#SORTED_NUMERIC} doc values; the
+ * {@link #REFTYPE_ATTR} attribute of the field type tells the indexer which kind of
+ * reference the field holds.
+ *
+ * @lucene.experimental
+ */
+public class ReferenceDocValuesField extends Field {
+
+  /** Name of the field-type attribute that records which kind of reference a field holds. */
+  public static final String REFTYPE_ATTR = "reftype";
+  /** {@link #REFTYPE_ATTR} value for plain references to docids. */
+  public static final String DOCID_ATTR_VALUE = "docid";
+  /** {@link #REFTYPE_ATTR} value for references that form a KNN graph. */
+  public static final String KNN_GRAPH_ATTR_VALUE = "knn-graph";
+
+  /**
+   * Type for reference DocValues. Possible extension point for other reference use cases.
+   * Not fully implemented.
+   * @lucene.experimental
+   */
+  static final FieldType TYPE = new FieldType();
+  static {
+    TYPE.setDocValuesType(DocValuesType.SORTED_NUMERIC);
+    // indicate this field holds references to docids
+    TYPE.putAttribute(REFTYPE_ATTR, DOCID_ATTR_VALUE);
+    TYPE.freeze();
+  }
+
+  /**
+   * Type for reference DocValues used in a KNN graph. References are made to be symmetric. This
+   * field will be recalculated on merge using a search for nearest neighbors in the merged segment.
+   * @lucene.experimental
+   */
+  public static final FieldType KNN_GRAPH_TYPE = new FieldType();
+  static {
+    KNN_GRAPH_TYPE.setDocValuesType(DocValuesType.SORTED_NUMERIC);
+    // indicate this field holds references to docids managed by a KNN graph
+    KNN_GRAPH_TYPE.putAttribute(REFTYPE_ATTR, KNN_GRAPH_ATTR_VALUE);
+    KNN_GRAPH_TYPE.freeze();
+  }
+
+  /**
+   * Creates a new reference field of the plain docid reference type.
+   * @param name field name
+   * @param value the docid being referenced
+   * @throws IllegalArgumentException if the field name is null
+   * @lucene.experimental
+   */
+  public ReferenceDocValuesField(String name, int value) {
+    super(name, TYPE);
+    fieldsData = Integer.valueOf(value);
+  }
+
+  /**
+   * Creates a new reference field with an explicit reference type.
+   * @param name field name
+   * @param value the docid being referenced
+   * @param type the type of reference, e.g. {@link #KNN_GRAPH_TYPE}; knn-graph references do not
+   *        remove references to deleted docs. Its doc values type must be
+   *        {@link DocValuesType#SORTED_NUMERIC}.
+   * @throws IllegalArgumentException if the field name is null, or if {@code type} does not use
+   *         {@link DocValuesType#SORTED_NUMERIC} doc values
+   * @lucene.experimental
+   */
+  public ReferenceDocValuesField(String name, int value, FieldType type) {
+    super(name, type);
+    // references are only ever consumed as numeric values of SORTED_NUMERIC doc values;
+    // reject other types up front rather than failing obscurely at indexing time
+    if (type.docValuesType() != DocValuesType.SORTED_NUMERIC) {
+      throw new IllegalArgumentException("ReferenceDocValuesField requires SORTED_NUMERIC doc values, got "
+          + type.docValuesType());
+    }
+    fieldsData = Integer.valueOf(value);
+  }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/document/VectorDocValuesField.java b/lucene/core/src/java/org/apache/lucene/document/VectorDocValuesField.java
new file mode 100644
index 0000000..43a02e4
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/document/VectorDocValuesField.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.document;
+
+import java.nio.ByteBuffer;
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.apache.lucene.index.DocValuesType;
+import org.apache.lucene.index.VectorDocValues;
+import org.apache.lucene.util.BytesRef;
+
+/**
+ * A doc-values field that holds one {@code float[]} per document; every document that has a
+ * value for the field uses the same vector length.
+ * <p>
+ * Example:
+ * </p>
+ *
+ * <pre class="prettyprint">
+ * document.add(new VectorDocValuesField(name, new float[]{1f, 2f, 3f}));
+ * </pre>
+ *
+ * @see org.apache.lucene.index.VectorDocValues
+ */
+public class VectorDocValuesField extends Field {
+
+  /** Lazily built, frozen field types shared by all fields of the same dimension. */
+  private static final ConcurrentHashMap<Integer, FieldType> TYPES = new ConcurrentHashMap<>();
+
+  /** Builds a frozen BINARY doc-values type tagged with {@code dimension}. */
+  private static FieldType createType(int dimension) {
+    if (dimension > 0) {
+      final FieldType fieldType = new FieldType();
+      fieldType.setDocValuesType(DocValuesType.BINARY);
+      fieldType.putAttribute(VectorDocValues.DIMENSION_ATTR, Integer.toString(dimension));
+      fieldType.freeze();
+      return fieldType;
+    }
+    throw new IllegalArgumentException("dimension must be positive, not " + dimension);
+  }
+
+  /**
+   * Returns the shared field type for vectors of length {@code dimension}, creating it on first use.
+   * @throws IllegalArgumentException if {@code dimension} is not positive
+   */
+  public static FieldType type(int dimension) {
+    return TYPES.computeIfAbsent(dimension, VectorDocValuesField::createType);
+  }
+
+  /**
+   * Creates a field whose binary doc value is {@code value} written as consecutive 4-byte floats
+   * in {@link ByteBuffer}'s default (big-endian) byte order.
+   * @param name field name
+   * @param value the vector; its length selects the field type
+   * @throws IllegalArgumentException if the field name is null or {@code value} is empty
+   */
+  public VectorDocValuesField(String name, float[] value) {
+    super(name, type(value.length));
+    final BytesRef encoded = new BytesRef(value.length * Float.BYTES);
+    ByteBuffer.wrap(encoded.bytes).asFloatBuffer().put(value);
+    encoded.length = encoded.bytes.length;
+    fieldsData = encoded;
+  }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java
index ffbd7e5..4e1a1f2 100644
--- a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java
+++ b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java
@@ -354,7 +354,7 @@ public final class CheckIndex implements Closeable {
/** Total number of sortednumeric fields */
public long totalSortedNumericFields;
-
+
/** Total number of sortedset fields */
public long totalSortedSetFields;
diff --git a/lucene/core/src/java/org/apache/lucene/index/DefaultIndexingChain.java b/lucene/core/src/java/org/apache/lucene/index/DefaultIndexingChain.java
index 12bfe07..35b97a4 100644
--- a/lucene/core/src/java/org/apache/lucene/index/DefaultIndexingChain.java
+++ b/lucene/core/src/java/org/apache/lucene/index/DefaultIndexingChain.java
@@ -37,6 +37,8 @@ import org.apache.lucene.codecs.NormsProducer;
import org.apache.lucene.codecs.PointsFormat;
import org.apache.lucene.codecs.PointsWriter;
import org.apache.lucene.document.FieldType;
+import org.apache.lucene.document.KnnGraphField;
+import org.apache.lucene.document.ReferenceDocValuesField;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
@@ -135,12 +137,6 @@ final class DefaultIndexingChain extends DocConsumer {
SegmentReadState readState = new SegmentReadState(state.directory, state.segmentInfo, state.fieldInfos, true, IOContext.READ, state.segmentSuffix, Collections.emptyMap());
t0 = System.nanoTime();
- writeDocValues(state, sortMap);
- if (docState.infoStream.isEnabled("IW")) {
- docState.infoStream.message("IW", ((System.nanoTime()-t0)/1000000) + " msec to write docValues");
- }
-
- t0 = System.nanoTime();
writePoints(state, sortMap);
if (docState.infoStream.isEnabled("IW")) {
docState.infoStream.message("IW", ((System.nanoTime()-t0)/1000000) + " msec to write points");
@@ -180,6 +176,12 @@ final class DefaultIndexingChain extends DocConsumer {
docState.infoStream.message("IW", ((System.nanoTime()-t0)/1000000) + " msec to write postings and finish vectors");
}
+ t0 = System.nanoTime();
+ writeDocValues(state, sortMap);
+ if (docState.infoStream.isEnabled("IW")) {
+ docState.infoStream.message("IW", ((System.nanoTime()-t0)/1000000) + " msec to write docValues");
+ }
+
// Important to save after asking consumer to flush so
// consumer can alter the FieldInfo* if necessary. EG,
// FreqProxTermsWriter does this with
@@ -602,9 +604,23 @@ final class DefaultIndexingChain extends DocConsumer {
case BINARY:
if (fp.docValuesWriter == null) {
- fp.docValuesWriter = new BinaryDocValuesWriter(fp.fieldInfo, bytesUsed);
+ if (KnnGraphField.KNN_GRAPH.equals(fp.fieldInfo.getAttribute(KnnGraphField.SUBTYPE_ATTR))) {
+ // nocommit uniquify field name - also move the naming logic into something Knn-specific, like in KnnGraphField?
+ PerField refField = getOrAddField(fp.fieldInfo.name + "$nbr", ReferenceDocValuesField.KNN_GRAPH_TYPE, false);
+ ReferenceDocValuesWriter refWriter = new ReferenceDocValuesWriter(refField.fieldInfo, bytesUsed);
+ refField.docValuesWriter = refWriter;
+ // TODO: isn't it strange that getOrAddField does not already do this?
+ refField.fieldInfo.setDocValuesType(ReferenceDocValuesField.KNN_GRAPH_TYPE.docValuesType());
+ fp.docValuesWriter = new KnnGraphWriter(fp.fieldInfo, bytesUsed, refWriter);
+ } else {
+ fp.docValuesWriter = new BinaryDocValuesWriter(fp.fieldInfo, bytesUsed);
+ }
+ }
+ if (KnnGraphField.KNN_GRAPH.equals(fp.fieldInfo.getAttribute(KnnGraphField.SUBTYPE_ATTR))) {
+ ((KnnGraphWriter) fp.docValuesWriter).addValue(docID, ((KnnGraphField) field).vectorValue());
+ } else {
+ ((BinaryDocValuesWriter) fp.docValuesWriter).addValue(docID, field.binaryValue());
}
- ((BinaryDocValuesWriter) fp.docValuesWriter).addValue(docID, field.binaryValue());
break;
case SORTED:
@@ -613,12 +629,17 @@ final class DefaultIndexingChain extends DocConsumer {
}
((SortedDocValuesWriter) fp.docValuesWriter).addValue(docID, field.binaryValue());
break;
-
+
case SORTED_NUMERIC:
if (fp.docValuesWriter == null) {
- fp.docValuesWriter = new SortedNumericDocValuesWriter(fp.fieldInfo, bytesUsed);
+ String refType = fp.fieldInfo.getAttribute(ReferenceDocValuesField.REFTYPE_ATTR);
+ if (refType != null) {
+ fp.docValuesWriter = new ReferenceDocValuesWriter(fp.fieldInfo, bytesUsed);
+ } else {
+ fp.docValuesWriter = new SortedNumericDocValuesWriter(fp.fieldInfo, bytesUsed);
+ }
}
- ((SortedNumericDocValuesWriter) fp.docValuesWriter).addValue(docID, field.numericValue().longValue());
+ ((SNDVWriterBase) fp.docValuesWriter).addValue(docID, field.numericValue().longValue());
break;
case SORTED_SET:
diff --git a/lucene/core/src/java/org/apache/lucene/index/DocValues.java b/lucene/core/src/java/org/apache/lucene/index/DocValues.java
index 63488d0..67103a0 100644
--- a/lucene/core/src/java/org/apache/lucene/index/DocValues.java
+++ b/lucene/core/src/java/org/apache/lucene/index/DocValues.java
@@ -393,7 +393,7 @@ public final class DocValues {
}
return dv;
}
-
+
/**
* Returns SortedSetDocValues for the field, or {@link #emptySortedSet} if it has none.
* @return docvalues instance, or an empty instance if {@code field} does not exist in this reader.
diff --git a/lucene/core/src/java/org/apache/lucene/index/DocsWithFieldSet.java b/lucene/core/src/java/org/apache/lucene/index/DocsWithFieldSet.java
index 6c0d6dd..dee8082 100644
--- a/lucene/core/src/java/org/apache/lucene/index/DocsWithFieldSet.java
+++ b/lucene/core/src/java/org/apache/lucene/index/DocsWithFieldSet.java
@@ -49,6 +49,22 @@ final class DocsWithFieldSet extends DocIdSet {
cost++;
}
+  /**
+   * Returns whether {@code docID} has been added to this set. Negative ids and ids beyond
+   * {@code lastDocId} are reported as absent.
+   */
+  boolean contains(int docID) {
+    if (docID < 0 || docID > lastDocId) {
+      // Without this guard the dense branch would report negative ids as present, and the sparse
+      // bit set is presumably only sized to cover ids up to the last added doc.
+      return false;
+    }
+    if (set == null) {
+      // no bit set: every id in 0..lastDocId is treated as present (dense run)
+      return true;
+    }
+    return set.get(docID);
+  }
+
+  /** Returns the current value of the {@code cost} counter (presumably the number of docs added). */
+  int cost() {
+    return this.cost;
+  }
+
+  /** Returns {@code lastDocId}, presumably the most recently added doc id. */
+  int lastDocId() {
+    return this.lastDocId;
+  }
+
@Override
public long ramBytesUsed() {
return BASE_RAM_BYTES_USED + (set == null ? 0 : set.ramBytesUsed());
diff --git a/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java
index 1d26d17..354e594 100644
--- a/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java
@@ -391,7 +391,7 @@ public abstract class FilterLeafReader extends LeafReader {
ensureOpen();
return in.getSortedDocValues(field);
}
-
+
@Override
public SortedNumericDocValues getSortedNumericDocValues(String field) throws IOException {
ensureOpen();
diff --git a/lucene/core/src/java/org/apache/lucene/index/KnnGraphWriter.java b/lucene/core/src/java/org/apache/lucene/index/KnnGraphWriter.java
new file mode 100644
index 0000000..553fb7c
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/index/KnnGraphWriter.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.index;
+
+
+import java.io.IOException;
+
+import org.apache.lucene.codecs.DocValuesConsumer;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.GraphSearch;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.SortField;
+import org.apache.lucene.util.Counter;
+
+/**
+ * Buffers the vectors of a knn-graph field and links each newly added document into the graph.
+ * Before a vector is buffered, the vectors added so far are searched with {@link GraphSearch}, and a
+ * reference from the new document to each hit is recorded in the supplied
+ * {@link ReferenceDocValuesWriter}. {@link #flush} writes only the vectors; the references are held
+ * by {@code refWriter}.
+ */
+class KnnGraphWriter extends DocValuesWriter {
+
+  private final FieldInfo fieldInfo;
+  private final Counter iwBytesUsed;
+  private final VectorDocValuesWriter vectorWriter;
+  private final ReferenceDocValuesWriter refWriter;
+  // random-access views over the buffered (not yet flushed) vectors and references
+  private final VectorDocValues vectorValues;
+  private final SortedNumericDocValues refs;
+  private final GraphSearch graphSearch;
+
+  // in the usual Lucene sense - the maximum doc id, plus one
+  private int maxDoc;
+
+  public KnnGraphWriter(FieldInfo fieldInfo, Counter iwBytesUsed, ReferenceDocValuesWriter refWriter) {
+    this.fieldInfo = fieldInfo;
+    this.iwBytesUsed = iwBytesUsed;
+    this.refWriter = refWriter;
+    refs = refWriter.getBufferedValues();
+    vectorWriter = new VectorDocValuesWriter(fieldInfo, iwBytesUsed);
+    vectorValues = vectorWriter.getBufferedValues();
+    // nocommit magic number 6: what value of topK should we use here?
+    graphSearch = GraphSearch.fromDimension(vectorValues.dimension());
+  }
+
+  /**
+   * Adds the vector for {@code docId}, first linking {@code docId} to the documents that the graph
+   * search returns for it among those added so far.
+   *
+   * @param docId the document; must be greater than every docId previously added
+   * @param vector the document's vector; must have exactly {@link VectorDocValues#dimension()} values
+   * @throws IllegalArgumentException if {@code docId} is out of order, or if {@code vector} is null
+   *         or has the wrong dimension. These checks run before any reference is recorded, so a
+   *         rejected document leaves the graph unchanged.
+   */
+  public void addValue(int docId, float[] vector) throws IOException {
+    // Validate first: the search below records references to docId in refWriter. If
+    // vectorWriter.addValue rejected the vector afterwards, the graph would point at a document
+    // that has no vector. The messages match those of VectorDocValuesWriter.addValue.
+    if (docId < maxDoc) {
+      throw new IllegalArgumentException("DocValuesField \"" + fieldInfo.name + "\" appears more than once in this document (only one value is allowed per field)");
+    }
+    if (vector == null) {
+      throw new IllegalArgumentException("field=\"" + fieldInfo.name + "\": null value not allowed");
+    }
+    if (vector.length != vectorValues.dimension()) {
+      throw new IllegalArgumentException("DocValuesField \"" + fieldInfo.name + "\" has the wrong dimension: " + vector.length + ". It must match its type whose dimension is " + vectorValues.dimension());
+    }
+    if (maxDoc > 0) {
+      for (ScoreDoc ref : graphSearch.search(() -> vectorValues, () -> refs, vector, maxDoc)) {
+        if (ref.doc >= 0) { // there can be sentinels present
+          refWriter.addValue(docId, ref.doc);
+        }
+      }
+    }
+    maxDoc = docId + 1;
+    vectorWriter.addValue(docId, vector);
+  }
+
+  /** Flushes the buffered vectors; the graph references are not written by this method. */
+  @Override
+  public void flush(SegmentWriteState state, Sorter.DocMap sortMap, DocValuesConsumer dvConsumer) throws IOException {
+    vectorWriter.flush(state, sortMap, dvConsumer);
+  }
+
+  /** Returns the documents that have a vector. */
+  @Override
+  DocIdSetIterator getDocIdSet() {
+    return vectorWriter.getDocIdSet();
+  }
+
+  /** Delegates to the vector writer, which forbids sorting on this field. */
+  @Override
+  Sorter.DocComparator getDocComparator(int numDoc, SortField sortField) throws IOException {
+    return vectorWriter.getDocComparator(numDoc, sortField);
+  }
+
+  @Override
+  public void finish(int maxDoc) {
+  }
+
+}
diff --git a/lucene/core/src/java/org/apache/lucene/index/ReferenceDocValuesWriter.java b/lucene/core/src/java/org/apache/lucene/index/ReferenceDocValuesWriter.java
new file mode 100644
index 0000000..d494866
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/index/ReferenceDocValuesWriter.java
@@ -0,0 +1,331 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.index;
+
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import org.apache.lucene.codecs.DocValuesConsumer;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.document.ReferenceDocValuesField;
+import org.apache.lucene.search.SortField;
+import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.Counter;
+import org.apache.lucene.util.RamUsageEstimator;
+import org.apache.lucene.util.packed.PackedInts;
+import org.apache.lucene.util.packed.PackedLongValues;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.apache.lucene.search.DocIdSetIterator.all;
+
+/** <p>Buffers up pending int[] of in-segment docids for each doc, then sorts, packs and flushes when
+ * segment flushes, using the same encoding as SortedNumericDocValues. There are two flavors,
+ * selected by the {@link org.apache.lucene.document.Field} attribute {@code reftype}, either {@code docid} or {@code
+ * knn-graph}. The {@code knn-graph} flavor models a connected, undirected graph. Each
+ * new document refers to previous documents, and links are made symmetric by this writer. So
+ * the simple path graph 1-2-3 would be encoded as:</p>
+ *
+ * <pre>
+ * 1: [ 2 ]
+ * 2: [ 1, 3 ]
+ * 3: [ 2 ]
+ * </pre>
+ *
+ * <p>The {@code docid} flavor does not enforce symmetry, and allows forward references to documents that
+ * have not yet been inserted. It can be said to model graphs more generally, allowing for directed,
+ * disconnected, and cyclic graphs. In both cases, references to nonexistent documents are purged
+ * when flushing and/or merging. References to deleted documents are purged too, except in the
+ * {@code knn-graph} flavor, which keeps them to preserve connectivity.</p>
+ *
+ * <p>NOTE(review): {@link #addValue} currently records the reverse link for both flavors, so the
+ * {@code docid} flavor is in fact symmetric here as well — confirm the intended behavior.</p>
+ *
+ * nocommit this is public unlike other DocValuesWriters so we can access it from DocValuesConsumer
+ *
+ * @lucene.internal
+ */
+
+public class ReferenceDocValuesWriter extends SNDVWriterBase {
+  private final Counter iwBytesUsed;
+  private final FieldInfo fieldInfo;
+  private final boolean isGraph;
+  private SortedMap<Integer, IntArray> allValues;
+  private IntArray currentValues;
+  private DocsWithFieldSet docsWithField = new DocsWithFieldSet();
+  private int currentDocId;
+
+  public ReferenceDocValuesWriter(FieldInfo fieldInfo, Counter iwBytesUsed) {
+    this.fieldInfo = fieldInfo;
+    this.iwBytesUsed = iwBytesUsed;
+    // TODO: add some overhead for currentValues map initial size
+    allValues = new TreeMap<>();
+    currentDocId = -1;
+    String refType = fieldInfo.getAttribute(ReferenceDocValuesField.REFTYPE_ATTR);
+    isGraph = "knn-graph".equals(refType);
+  }
+
+  /**
+   * Records a reference from {@code docID} to {@code value}, and the reverse reference from
+   * {@code value} to {@code docID}.
+   *
+   * @throws IllegalArgumentException if {@code value} is negative or equal to {@code docID}
+   */
+  @Override
+  public void addValue(int docID, long value) {
+    if (value < 0 || value == docID) {
+      // nocommit: where is the right place to enforce that value <= maxDoc? When merging value > docID occurs
+      throw new IllegalArgumentException("ReferenceDocValues.addValue " + value + " is not a valid docID");
+    }
+    int ivalue = (int) value;
+    IntArray referent = allValues.get(ivalue);
+    long used;
+    if (referent == null) {
+      assert allValues.isEmpty() : "ReferenceDocValues reference to document not in graph: " + value;
+      // We must special-case the first entry. Because we require forward-only iteration (why?) we need
+      // to bootstrap the nonexistent root node
+      referent = new IntArray(docID);
+      allValues.put(ivalue, referent);
+      used = newEntryBytes(referent);
+      currentDocId = ivalue;
+    } else {
+      used = referent.add(docID);
+    }
+    if (docID != currentDocId) {
+      // During merge docIDs may arrive out of order, so docID can already have an entry. Append to
+      // it: replacing it would silently drop the references recorded so far.
+      IntArray existing = allValues.get(docID);
+      if (existing == null) {
+        currentValues = new IntArray(ivalue);
+        allValues.put(docID, currentValues);
+        used += newEntryBytes(currentValues);
+      } else {
+        currentValues = existing;
+        used += currentValues.add(ivalue);
+      }
+      currentDocId = docID;
+    } else {
+      used += currentValues.add(ivalue);
+    }
+    iwBytesUsed.addAndGet(used);
+  }
+
+  /** Rough heap estimate for a new {@link IntArray}: object header, size field and backing array. */
+  private static long newEntryBytes(IntArray array) {
+    // nocommit improve this estimate
+    return RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + Integer.BYTES + RamUsageEstimator.sizeOf(array.values);
+  }
+
+  /** Returns whether any reference has been buffered from or to {@code docID}. */
+  boolean hasValue(int docID) {
+    return allValues.containsKey(docID);
+  }
+
+  @Override
+  public void finish(int maxDoc) {
+  }
+
+  @Override
+  Sorter.DocComparator getDocComparator(int maxDoc, SortField sortField) {
+    throw new IllegalArgumentException("It is forbidden to sort by a ReferenceDocValues field");
+  }
+
+  /**
+   * Writes the buffered references as a SortedNumeric field, then clears the buffer. Each doc's
+   * references are:
+   * <ul>
+   *   <li>filtered, dropping references to documents that do not exist, and to deleted documents
+   *   unless this is a knn-graph;</li>
+   *   <li>remapped through {@code sortMap};</li>
+   *   <li>sorted and de-duplicated.</li>
+   * </ul>
+   */
+  @Override
+  public void flush(SegmentWriteState state, Sorter.DocMap sortMap, DocValuesConsumer dvConsumer) throws IOException {
+    final int maxDoc = state.segmentInfo.maxDoc();
+    PackedLongValues.Builder valuesBuilder = PackedLongValues.deltaPackedBuilder(PackedInts.COMPACT); // stream of all values
+    PackedLongValues.Builder countsBuilder = PackedLongValues.deltaPackedBuilder(PackedInts.COMPACT); // count of values per doc
+    for (Map.Entry<Integer, IntArray> e : allValues.entrySet()) {
+      int doc = e.getKey();
+      if (doc >= maxDoc) {
+        // an entry for a nonexistent document: purge it like references to such documents
+        continue;
+      }
+      IntArray values = e.getValue();
+      // Filter before remapping: liveDocs is indexed by the pre-sort docids. Deleted referents are
+      // kept for a knn-graph to preserve graph connectivity. The array is compacted in place, which
+      // is safe because allValues is cleared at the end of flush.
+      int kept = 0;
+      for (int i = 0; i < values.size; i++) {
+        int ref = values.values[i];
+        if (ref < maxDoc && (isGraph || state.liveDocs == null || state.liveDocs.get(ref))) {
+          values.values[kept++] = sortMap == null ? ref : sortMap.oldToNew(ref);
+        }
+      }
+      values.size = kept;
+      // Sort the values in ascending order as required by SortedNumericDocValues format
+      values.sort();
+      int lastValue = -1;
+      long countBefore = valuesBuilder.size();
+      for (int i = 0; i < values.size; i++) {
+        int value = values.values[i];
+        if (value != lastValue) {
+          // eliminate duplicate values
+          lastValue = value;
+          valuesBuilder.add(value);
+        }
+      }
+      long numAdded = valuesBuilder.size() - countBefore;
+      // record the number of values for this doc
+      if (numAdded > 0) {
+        // may be zero if deletions caused all references to be dropped
+        docsWithField.add(doc);
+        countsBuilder.add(numAdded);
+      }
+    }
+
+    final PackedLongValues values = valuesBuilder.build();
+    final PackedLongValues valueCounts = countsBuilder.build();
+
+    final long[][] sorted;
+    if (sortMap != null) {
+      sorted = SortedNumericDocValuesWriter.sortDocValues(maxDoc, sortMap,
+          new SortedNumericDocValuesWriter.BufferedSortedNumericDocValues(values, valueCounts, docsWithField.iterator()));
+    } else {
+      sorted = null;
+    }
+
+    dvConsumer.addSortedNumericField(fieldInfo,
+        new EmptyDocValuesProducer() {
+          @Override
+          public SortedNumericDocValues getSortedNumeric(FieldInfo fieldInfoIn) {
+            if (fieldInfoIn != fieldInfo) {
+              throw new IllegalArgumentException("wrong fieldInfo");
+            }
+            final SortedNumericDocValues buf =
+                new SortedNumericDocValuesWriter.BufferedSortedNumericDocValues(values, valueCounts, docsWithField.iterator());
+            if (sorted == null) {
+              return buf;
+            } else {
+              return new SortingLeafReader.SortingSortedNumericDocValues(buf, sorted);
+            }
+          }
+        });
+    // Is this re-used??
+    allValues.clear();
+    docsWithField = null;
+  }
+
+  @Override
+  DocIdSetIterator getDocIdSet() {
+    return docsWithField.iterator();
+  }
+
+  /** Returns a forward-only iterator over the buffered references, in docID order. */
+  public SortedNumericDocValues getIterableValues() {
+    return new IterableValues();
+  }
+
+  /** Returns a view over the buffered references that supports only {@code advanceExact}. */
+  public SortedNumericDocValues getBufferedValues() {
+    return new RandomAccessValues();
+  }
+
+  /** Common state for views over {@code allValues}: the current doc and its references. */
+  private abstract class BufferedReferenceDocValues extends SortedNumericDocValues {
+    int docID = -1;
+    IntArray values;
+    int valueUpTo;
+
+    @Override
+    public int docID() {
+      return docID;
+    }
+
+    @Override
+    public int docValueCount() {
+      return values.size;
+    }
+
+    @Override
+    public long nextValue() {
+      return values.values[valueUpTo++];
+    }
+
+    @Override
+    public long cost() {
+      // docsWithField is only populated during flush (and nulled afterwards), so count buffered docs
+      return allValues.size();
+    }
+  }
+
+  private class IterableValues extends BufferedReferenceDocValues {
+    private final Iterator<Map.Entry<Integer, IntArray>> iterator = allValues.entrySet().iterator();
+
+    @Override
+    public int nextDoc() throws IOException {
+      if (iterator.hasNext()) {
+        Map.Entry<Integer, IntArray> entry = iterator.next();
+        docID = entry.getKey();
+        values = entry.getValue();
+        // Sort the values in ascending order as required by SortedNumericDocValues format
+        // Ideally we only create a single one of these iterators, but TODO: create a finish() method and do this
+        // when we are done creating the merged writer
+        values.sort();
+        valueUpTo = 0;
+      } else {
+        docID = NO_MORE_DOCS;
+        values = null;
+      }
+      return docID;
+    }
+
+    @Override
+    public int advance(int target) {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public boolean advanceExact(int target) throws IOException {
+      throw new UnsupportedOperationException();
+    }
+
+  }
+
+  private class RandomAccessValues extends BufferedReferenceDocValues {
+
+    @Override
+    public int nextDoc() throws IOException {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public int advance(int target) {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public boolean advanceExact(int target) throws IOException {
+      // per the advanceExact contract, docID() returns target whether or not it has values
+      docID = target;
+      values = allValues.get(target);
+      valueUpTo = 0;
+      return values != null;
+    }
+
+  }
+
+  /** Growable int array holding the references of a single document. */
+  private static class IntArray {
+    int size;
+    int[] values = new int[8];
+
+    IntArray(int first) {
+      values[0] = first;
+      size = 1;
+    }
+
+    /**
+     * Appends {@code value}, growing the backing array first if it is full.
+     * @param value the value to append to the array
+     * @return bytes allocated if the array was resized, or zero
+     */
+    int add(int value) {
+      int used = 0;
+      if (size == values.length) {
+        int oldLength = values.length;
+        values = ArrayUtil.grow(values, size + 1);
+        used = (values.length - oldLength) * Integer.BYTES;
+      }
+      values[size++] = value;
+      return used;
+    }
+
+    void sort() {
+      Arrays.sort(values, 0, size);
+    }
+  }
+
+}
+
diff --git a/lucene/core/src/java/org/apache/lucene/index/SNDVWriterBase.java b/lucene/core/src/java/org/apache/lucene/index/SNDVWriterBase.java
new file mode 100644
index 0000000..6e4fb4e
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/index/SNDVWriterBase.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.index;
+
+/**
+ * Common base for doc-values writers that accept {@code long} values per document through
+ * {@link #addValue(int, long)}. In this change it is extended by SortedNumericDocValuesWriter and
+ * ReferenceDocValuesWriter.
+ */
+abstract class SNDVWriterBase extends DocValuesWriter {
+
+ /** Records {@code value} for document {@code docID}; the meaning of the value is defined by the subclass. */
+ public abstract void addValue(int docID, long value);
+
+}
diff --git a/lucene/core/src/java/org/apache/lucene/index/SortedNumericDocValuesWriter.java b/lucene/core/src/java/org/apache/lucene/index/SortedNumericDocValuesWriter.java
index bdc65cc..cec69a4 100644
--- a/lucene/core/src/java/org/apache/lucene/index/SortedNumericDocValuesWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/index/SortedNumericDocValuesWriter.java
@@ -34,7 +34,7 @@ import org.apache.lucene.util.packed.PackedLongValues;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
/** Buffers up pending long[] per doc, sorts, then flushes when segment flushes. */
-class SortedNumericDocValuesWriter extends DocValuesWriter {
+class SortedNumericDocValuesWriter extends SNDVWriterBase {
private PackedLongValues.Builder pending; // stream of all values
private PackedLongValues.Builder pendingCounts; // count of values per doc
private DocsWithFieldSet docsWithField;
@@ -58,6 +58,7 @@ class SortedNumericDocValuesWriter extends DocValuesWriter {
iwBytesUsed.addAndGet(bytesUsed);
}
+ @Override
public void addValue(int docID, long value) {
assert docID >= currentDoc;
if (docID != currentDoc) {
@@ -118,7 +119,7 @@ class SortedNumericDocValuesWriter extends DocValuesWriter {
() -> SortedNumericSelector.wrap(docValues, sf.getSelector(), sf.getNumericType()));
}
- private long[][] sortDocValues(int maxDoc, Sorter.DocMap sortMap, SortedNumericDocValues oldValues) throws IOException {
+ static long[][] sortDocValues(int maxDoc, Sorter.DocMap sortMap, SortedNumericDocValues oldValues) throws IOException {
long[][] values = new long[maxDoc][];
int docID;
while ((docID = oldValues.nextDoc()) != NO_MORE_DOCS) {
@@ -170,7 +171,7 @@ class SortedNumericDocValuesWriter extends DocValuesWriter {
});
}
- private static class BufferedSortedNumericDocValues extends SortedNumericDocValues {
+ static class BufferedSortedNumericDocValues extends SortedNumericDocValues {
final PackedLongValues.Iterator valuesIter;
final PackedLongValues.Iterator valueCountsIter;
final DocIdSetIterator docsWithField;
diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorDocValues.java b/lucene/core/src/java/org/apache/lucene/index/VectorDocValues.java
new file mode 100644
index 0000000..ca5e6ad
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/index/VectorDocValues.java
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.index;
+
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.FloatBuffer;
+
+import org.apache.lucene.store.ByteBufferIndexInput;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.automaton.CompiledAutomaton;
+
+/**
+ * A per-document fixed-length array of float values.
+ * TODO: genericize to support doubles and integers
+ * JEP-338 would be a nice help here.
+ */
+public abstract class VectorDocValues extends DocValuesIterator {
+
+  /** Name of the {@link FieldInfo} attribute that holds a vector field's dimension. */
+  public static final String DIMENSION_ATTR = "dimension";
+
+  /** Sole constructor. (For invocation by subclass constructors, typically implicit.) */
+  protected VectorDocValues() {}
+
+  /**
+   * @return the number of values in each vector, which does not vary by document. This must always
+   * be greater than zero unless there are no documents with this value, in which case it is -1.
+   */
+  public abstract int dimension();
+
+  /**
+   * Provides direct access to an array of values.
+   * @param vector an array of values of size {@link #dimension()}. The array values are copied from the
+   * field values into this array. It is an error to call this method if the iterator is not
+   * positioned (ie one of the advance methods returned NO_MORE_DOCS, or false).
+   */
+  public abstract void vector(float[] vector) throws IOException;
+
+  /**
+   * Returns VectorDocValues decoding the binary doc values of {@code field}.
+   * <p>If {@code field} does not exist in this reader, the result wraps the empty instance returned
+   * by {@link DocValues#getBinary} and its {@link #dimension()} is -1.
+   * @return docvalues instance; never null.
+   * @throws IllegalStateException if {@code field} exists, but was not indexed with docvalues.
+   * @throws IllegalStateException if {@code field} has docvalues, but the type is not {@link DocValuesType#BINARY}
+   *         or {@link DocValuesType#SORTED}
+   * @throws IllegalStateException if {@code field} exists but has no {@link #DIMENSION_ATTR} attribute.
+   * @throws IOException if an I/O error occurs.
+   */
+  public static VectorDocValues get(LeafReader reader, String field) throws IOException {
+    return new BinaryVectorDV(reader, field);
+  }
+
+  /**
+   * Wraps {@code bdv} as VectorDocValues of the given {@code dimension}. Each binary value is
+   * decoded as {@code dimension} floats through {@link ByteBuffer#asFloatBuffer()}, using the
+   * default (big-endian) byte order.
+   */
+  public static VectorDocValues get(BinaryDocValues bdv, int dimension) {
+    return new BinaryVectorDV(bdv, dimension);
+  }
+
+}
+
+/** {@link VectorDocValues} that decodes each value of a {@link BinaryDocValues} as packed floats. */
+class BinaryVectorDV extends VectorDocValues {
+
+  final BinaryDocValues bdv;
+  final int dimension;
+
+  BinaryVectorDV(LeafReader reader, String field) throws IOException {
+    this(DocValues.getBinary(reader, field), reader.getFieldInfos().fieldInfo(field));
+  }
+
+  /**
+   * Reads the dimension from {@code fieldInfo}'s {@link #DIMENSION_ATTR} attribute, or uses -1 when
+   * {@code fieldInfo} is null (the field does not exist).
+   * @throws IllegalStateException if the attribute is missing or is not a positive integer
+   */
+  BinaryVectorDV(BinaryDocValues bdv, FieldInfo fieldInfo) throws IOException {
+    if (fieldInfo != null) {
+      String dimStr = fieldInfo.getAttribute(DIMENSION_ATTR);
+      if (dimStr == null) {
+        throw new IllegalStateException("DocValues type for field '" + fieldInfo.name + "' was indexed without a dimension.");
+      }
+      int dim = Integer.parseInt(dimStr);
+      // the attribute comes from the index, so check it for real rather than with an assert
+      if (dim <= 0) {
+        throw new IllegalStateException("DocValues type for field '" + fieldInfo.name + "' has invalid dimension " + dim);
+      }
+      dimension = dim;
+    } else {
+      dimension = -1;
+    }
+    this.bdv = bdv;
+  }
+
+  BinaryVectorDV(BinaryDocValues bdv, int dimension) {
+    this.bdv = bdv;
+    this.dimension = dimension;
+  }
+
+  @Override
+  public int dimension() {
+    return dimension;
+  }
+
+  /** Copies the current value's floats (default big-endian byte order) into {@code vector}. */
+  @Override
+  public void vector(float[] vector) throws IOException {
+    BytesRef b = bdv.binaryValue();
+    ByteBuffer.wrap(b.bytes, b.offset, b.length).asFloatBuffer().get(vector);
+  }
+
+  @Override
+  public boolean advanceExact(int target) throws IOException {
+    return bdv.advanceExact(target);
+  }
+
+  @Override
+  public int docID() {
+    return bdv.docID();
+  }
+
+  @Override
+  public int nextDoc() throws IOException {
+    return bdv.nextDoc();
+  }
+
+  @Override
+  public int advance(int target) throws IOException {
+    return bdv.advance(target);
+  }
+
+  @Override
+  public long cost() {
+    return bdv.cost();
+  }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorDocValuesWriter.java b/lucene/core/src/java/org/apache/lucene/index/VectorDocValuesWriter.java
new file mode 100644
index 0000000..310c8df
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/index/VectorDocValuesWriter.java
@@ -0,0 +1,276 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.index;
+
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.FloatBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.lucene.codecs.DocValuesConsumer;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.SortField;
+import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.BytesRefBuilder;
+import org.apache.lucene.util.Counter;
+import org.apache.lucene.util.FixedBitSet;
+import org.apache.lucene.util.RamUsageEstimator;
+import org.apache.lucene.util.packed.PackedInts;
+import org.apache.lucene.util.packed.PackedLongValues;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+/** Writes float arrays using the same methods as BinaryDocValuesWriter while exposing its internal
+ * buffer for random access reads. Vectors are appended in docID order to fixed-size heap
+ * {@link ByteBuffer} blocks; {@code docBufferMap} maps each vector's ordinal to its docID. */
+public class VectorDocValuesWriter extends DocValuesWriter {
+
+  /**
+   * Maximum dimension for a vector field. Each vector is copied into a single byte array of
+   * {@code dimension * Float.BYTES} bytes, which must fit within the maximum array length.
+   */
+  private static final int MAX_LENGTH = ArrayUtil.MAX_ARRAY_LENGTH / Float.BYTES;
+
+  // Target size in bytes of each buffer block (a block always holds at least one vector):
+  private final static int BLOCK_SIZE = 8192;
+
+  private final List<ByteBuffer> buffers;
+  private final Counter iwBytesUsed;
+  private final PackedLongValues.Builder lengths;
+  private final FieldInfo fieldInfo;
+  private final int dimension;
+  private final int bufferCapacity; // how many documents in a buffer
+  private final DocsWithFieldSet docsWithField;
+
+  // We record all docids having vectors so we can get their index with a binary search
+  // nocommit Can we find a more efficient data structure for this map. We need a random access data
+  // structure in order to find the vector for a given docid, so we can perform KNN graph searches
+  // while adding docs. We *also* need to be able to iterate over docids with values.
+  private int[] docBufferMap;
+  private FloatBuffer currentBuffer;
+
+  private long bytesUsed;
+  private int lastDocID = -1;
+
+  public VectorDocValuesWriter(FieldInfo fieldInfo, Counter iwBytesUsed) {
+    this.fieldInfo = fieldInfo;
+    this.iwBytesUsed = iwBytesUsed;
+    dimension = getDimensionFromAttribute(fieldInfo);
+    // At least one vector per buffer. Otherwise a vector larger than BLOCK_SIZE would give a capacity
+    // of 0, leading to zero-length buffers and division by zero when locating vectors.
+    bufferCapacity = Math.max(1, BLOCK_SIZE / dimension / Float.BYTES);
+    buffers = new ArrayList<>();
+    docBufferMap = new int[bufferCapacity];
+    allocateNewBuffer();
+    lengths = PackedLongValues.deltaPackedBuilder(PackedInts.COMPACT);
+    docsWithField = new DocsWithFieldSet();
+    updateBytesUsed();
+  }
+
+  /**
+   * Parses the {@link VectorDocValues#DIMENSION_ATTR} attribute of {@code fieldInfo}.
+   * @throws IllegalStateException if the attribute is missing
+   * @throws IllegalArgumentException if the dimension is not positive or exceeds the maximum
+   */
+  public static int getDimensionFromAttribute(FieldInfo fieldInfo) {
+    String dimStr = fieldInfo.getAttribute(VectorDocValues.DIMENSION_ATTR);
+    if (dimStr == null) {
+      // TODO: make this impossible?
+      throw new IllegalStateException("DocValues type for vector field '" + fieldInfo.name + "' was indexed without a dimension.");
+    }
+    int dim = Integer.parseInt(dimStr);
+    if (dim <= 0) {
+      throw new IllegalArgumentException("DocValuesField \"" + fieldInfo.name + "\" has dimension " + dim + ", which must be positive");
+    }
+    if (dim > MAX_LENGTH) {
+      throw new IllegalArgumentException("DocValuesField \"" + fieldInfo.name + "\" has dimension " + dim + ", which exceeds the maximum: " + MAX_LENGTH);
+    }
+    return dim;
+  }
+
+  /**
+   * Buffers {@code value} as the vector of {@code docID}.
+   * @throws IllegalArgumentException if {@code docID} is not greater than the previously added docID,
+   *         or if {@code value} is null or does not have exactly {@code dimension} values
+   */
+  public void addValue(int docID, float[] value) {
+    if (docID <= lastDocID) {
+      throw new IllegalArgumentException("DocValuesField \"" + fieldInfo.name + "\" appears more than once in this document (only one value is allowed per field)");
+    }
+    if (value == null) {
+      throw new IllegalArgumentException("field=\"" + fieldInfo.name + "\": null value not allowed");
+    }
+    if (value.length != dimension) {
+      throw new IllegalArgumentException("DocValuesField \"" + fieldInfo.name + "\" has the wrong dimension: " + value.length + ". It must match its type whose dimension is " + dimension);
+    }
+    ensureCapacity();
+    // TODO: all the lengths are the same: can we just write out some degenerate PackedLongValues?
+    lengths.add(value.length);
+    currentBuffer.put(value);
+    updateBytesUsed();
+    docBufferMap[docsWithField.cost()] = docID;
+    docsWithField.add(docID);
+    lastDocID = docID;
+  }
+
+  private void ensureCapacity() {
+    if (currentBuffer.hasRemaining() == false) {
+      allocateNewBuffer();
+    }
+  }
+
+  /** Appends a block holding {@code bufferCapacity} vectors and grows {@code docBufferMap} to match. */
+  private void allocateNewBuffer() {
+    ByteBuffer block = ByteBuffer.allocate(bufferCapacity * dimension * Float.BYTES);
+    buffers.add(block);
+    currentBuffer = block.asFloatBuffer();
+    docBufferMap = ArrayUtil.grow(docBufferMap, buffers.size() * bufferCapacity);
+  }
+
+  private void updateBytesUsed() {
+    // count the bytes actually allocated per block, in long arithmetic to avoid int overflow
+    final long newBytesUsed =
+        lengths.ramBytesUsed()
+        + (long) buffers.size() * bufferCapacity * dimension * Float.BYTES
+        + docsWithField.ramBytesUsed()
+        + RamUsageEstimator.sizeOf(docBufferMap);
+    iwBytesUsed.addAndGet(newBytesUsed - bytesUsed);
+    bytesUsed = newBytesUsed;
+  }
+
+  @Override
+  public void finish(int maxDoc) {
+  }
+
+  @Override
+  Sorter.DocComparator getDocComparator(int numDoc, SortField sortField) throws IOException {
+    throw new IllegalArgumentException("It is forbidden to sort on a binary field");
+  }
+
+  private SortingLeafReader.CachedBinaryDVs sortDocValues(int maxDoc, Sorter.DocMap sortMap, BinaryDocValues oldValues) throws IOException {
+    FixedBitSet docsWithField = new FixedBitSet(maxDoc);
+    BytesRef[] values = new BytesRef[maxDoc];
+    while (true) {
+      int docID = oldValues.nextDoc();
+      if (docID == NO_MORE_DOCS) {
+        break;
+      }
+      int newDocID = sortMap.oldToNew(docID);
+      docsWithField.set(newDocID);
+      values[newDocID] = BytesRef.deepCopyOf(oldValues.binaryValue());
+    }
+    return new SortingLeafReader.CachedBinaryDVs(values, docsWithField);
+  }
+
+  @Override
+  public void flush(SegmentWriteState state, Sorter.DocMap sortMap, DocValuesConsumer dvConsumer) throws IOException {
+    final SortingLeafReader.CachedBinaryDVs sorted;
+    if (sortMap != null) {
+      sorted = sortDocValues(state.segmentInfo.maxDoc(), sortMap, new BufferedBinaryDocValues());
+    } else {
+      sorted = null;
+    }
+    dvConsumer.addBinaryField(fieldInfo,
+        new EmptyDocValuesProducer() {
+          @Override
+          public BinaryDocValues getBinary(FieldInfo fieldInfoIn) {
+            if (fieldInfoIn != fieldInfo) {
+              throw new IllegalArgumentException("wrong fieldInfo");
+            }
+            if (sorted == null) {
+              return new BufferedBinaryDocValues();
+            } else {
+              return new SortingLeafReader.SortingBinaryDocValues(sorted);
+            }
+          }
+        });
+  }
+
+  /** Returns a random-access view over the vectors buffered so far. */
+  public VectorDocValues getBufferedValues() {
+    // nocommit avoid the extra bytes copy by writing a BufferedVectorDocValues with a FloatBuffer wrapping the byte buffers
+    return VectorDocValues.get(new BufferedBinaryDocValues(), dimension);
+  }
+
+  // Iterates over the values as bytes in forward order. Also supports random access positioning via advanceExact.
+  private class BufferedBinaryDocValues extends BinaryDocValues {
+    final private int stride;
+    final private BytesRefBuilder value;
+
+    private int currentDocIndex;
+    private int docId;
+
+    BufferedBinaryDocValues() {
+      value = new BytesRefBuilder();
+      stride = dimension * Float.BYTES;
+      value.grow(stride);
+      value.setLength(stride);
+      currentDocIndex = -1;
+      docId = -1;
+    }
+
+    @Override
+    public int docID() {
+      return docId;
+    }
+
+    @Override
+    public int nextDoc() throws IOException {
+      if (++currentDocIndex < docsWithField.cost()) {
+        copyVector();
+        docId = docBufferMap[currentDocIndex];
+      } else {
+        docId = NO_MORE_DOCS;
+      }
+      return docId;
+    }
+
+    /** Copies the bytes of the vector at ordinal {@code currentDocIndex} into {@code value}. */
+    private void copyVector() {
+      ByteBuffer buffer = buffers.get(currentDocIndex / bufferCapacity);
+      buffer.position((currentDocIndex % bufferCapacity) * stride);
+      buffer.get(value.bytes(), 0, stride);
+    }
+
+    @Override
+    public int advance(int target) {
+      int count = docsWithField.cost();
+      int idx = Arrays.binarySearch(docBufferMap, 0, count, target);
+      if (idx < 0) {
+        // target has no vector: move to the first doc after it, if any. Bound by the number of
+        // vectors, not docBufferMap.length: the map is over-allocated and its tail is unused.
+        idx = -1 - idx;
+        if (idx >= count) {
+          currentDocIndex = count;
+          docId = NO_MORE_DOCS;
+          return NO_MORE_DOCS;
+        }
+      }
+      currentDocIndex = idx;
+      copyVector();
+      docId = docBufferMap[currentDocIndex];
+      return docId;
+    }
+
+    @Override
+    public boolean advanceExact(int target) throws IOException {
+      int idx = Arrays.binarySearch(docBufferMap, 0, docsWithField.cost(), target);
+      // per the advanceExact contract, docID() returns target whether or not it has a value
+      docId = target;
+      if (idx < 0) {
+        // sit just before the insertion point so that nextDoc() moves to the first doc after target
+        currentDocIndex = -2 - idx;
+        return false;
+      }
+      currentDocIndex = idx;
+      copyVector();
+      return true;
+    }
+
+    @Override
+    public long cost() {
+      return docsWithField.cost();
+    }
+
+    @Override
+    public BytesRef binaryValue() {
+      return value.get();
+    }
+  }
+
+  @Override
+  DocIdSetIterator getDocIdSet() {
+    return docsWithField.iterator();
+  }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/search/GraphSearch.java b/lucene/core/src/java/org/apache/lucene/search/GraphSearch.java
new file mode 100644
index 0000000..da187de
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/search/GraphSearch.java
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+
+import java.io.IOException;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.TreeSet;
+import java.util.function.Supplier;
+
+import org.apache.lucene.index.DocValues;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.SortedNumericDocValues;
+import org.apache.lucene.index.VectorDocValues;
+import org.apache.lucene.util.PriorityQueue;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+/**
+ * Approximate nearest-neighbor search over a graph whose nodes are documents. Each document's
+ * vector is read from a {@link VectorDocValues} field, and its neighbors in the graph from a
+ * companion {@link SortedNumericDocValues} field. A search starts from a few entry documents and
+ * greedily expands the neighbors of the closest candidates found so far. Documents are ranked by
+ * squared Euclidean distance to the target; smaller is nearer.
+ *
+ * <p>nocommit Should this class be moved to o.a.l.index ? It currently exposes at least one method
+ * that is not really part of its intended public api in order to make it accessible to
+ * KnnGraphWriter.
+ */
+public class GraphSearch {
+
+  /** When true, print a trace of each search to standard output (debugging aid). */
+  public static boolean VERBOSE;
+
+  /** Docids whose distance to the target was already computed during the current search. */
+  private final Set<Integer> visited = new HashSet<>();
+  private final IndexSearcher searcher;
+  private final String vectorField;
+  private final String neighborField;
+  private final int topK;
+  /**
+   * Competitive candidates whose neighbors have not been expanded yet. The ordering makes
+   * {@link TreeSet#pollLast()} return the nearest one.
+   */
+  private final TreeSet<ScoreDoc> frontier;
+
+  /** Reusable buffer for the vector of the document currently being scored. */
+  private float[] scratch;
+  /** The best candidates of the current search; its top is the farthest one retained. */
+  private ScoreDocQueue queue;
+
+  /**
+   * Creates an instance without an IndexSearcher. Such an instance can only be used through
+   * {@link #search(SupplierThrowsIoe, SupplierThrowsIoe, float[], int)}.
+   * @param topK how many results to return when searching, and how many nearest neighbors (fanout)
+   * to connect while indexing
+   */
+  public GraphSearch(int topK) {
+    this(null, null, null, topK);
+  }
+
+  /**
+   * Creates an instance whose fanout is chosen for vectors of the given dimension.
+   * @param dimension the vector dimension; currently ignored, and a fixed fanout of 60 is used
+   */
+  public static GraphSearch fromDimension(int dimension) {
+    // TODO: experiment to find out how we can best set these heuristics
+    // Malkov, Ponomarenko, Logvinov, Krylov found 3*dim optimal for dim <= 20
+    // Their statements about how many iters to run while indexing amount to running a monte carlo experiment
+    // return new GraphSearch((int) (180 * (Math.log(1 + dimension / 20.0))));
+    // return new GraphSearch(3 * dimension);
+    return new GraphSearch(60);
+  }
+
+  private GraphSearch(IndexSearcher searcher, String vectorField, String neighborField, int topK) {
+    this.searcher = searcher;
+    this.vectorField = vectorField;
+    this.neighborField = neighborField;
+    this.topK = topK; // PriorityQueue could expose this, but does not
+    frontier = new TreeSet<>(GraphSearch::compareScoreDoc);
+  }
+
+  /**
+   * Find the topK (approximate) nearest neighbors to target, searching every segment of the index.
+   * @param searcher the searcher whose index is searched
+   * @param knnGraphField the vector field. Graph neighbors are read from the field
+   *        {@code knnGraphField + "$nbr"}.
+   * @param topK how many results to return
+   * @param numProbe how many entry points of the graph to probe in each segment
+   * @param target the query vector
+   * @return a TopDocs listing the topK (approximate) nearest neighbors to target, in order of
+   * increasing distance and docid. Scores are negated squared distances.
+   * @throws IllegalArgumentException if topK or numProbe is not positive
+   * @throws IOException when there is an underlying exception reading the index
+   */
+  public static TopDocs search(IndexSearcher searcher, String knnGraphField, int topK, int numProbe, float[] target)
+      throws IOException {
+    if (topK <= 0) {
+      throw new IllegalArgumentException("topK must be positive, got " + topK);
+    }
+    if (numProbe <= 0) {
+      throw new IllegalArgumentException("numProbe must be positive, got " + numProbe);
+    }
+    return new GraphSearch(searcher, knnGraphField, knnGraphField + "$nbr", topK).search(target, numProbe);
+  }
+
+  private TopDocs search(float[] target, int numProbe) throws IOException {
+    // TODO: implement a Query and let IndexSearcher/Collector handle this
+    List<LeafReaderContext> leaves = searcher.getIndexReader().leaves();
+    ScoreDocQueue[] segmentQueues = new ScoreDocQueue[leaves.size()];
+    for (LeafReaderContext context : leaves) {
+      LeafReader reader = context.reader();
+      if (VERBOSE) {
+        System.out.printf("[GraphSearch] segment #%d [%d docs]\n", context.ord, reader.maxDoc());
+      }
+      // each segment is searched into its own queue; the per-segment results are merged below
+      queue = new ScoreDocQueue(topK, () -> new ScoreDoc(-1, Float.MAX_VALUE), false);
+      doSearch(() -> VectorDocValues.get(reader, vectorField),
+          () -> DocValues.getSortedNumeric(reader, neighborField),
+          target, reader.maxDoc(), numProbe);
+      segmentQueues[context.ord] = queue;
+    }
+    return constructResults(segmentQueues, leaves);
+  }
+
+  /** Like {@link Supplier}, except that {@link #get()} may throw {@link IOException}. */
+  @FunctionalInterface
+  public interface SupplierThrowsIoe<T> {
+    T get() throws IOException;
+  }
+
+  /**
+   * Find the (approximate) nearest neighbor documents to the given target vector. Used when
+   * indexing - not intended as a public method.
+   * @param vectorsFactory Creates VectorDocValues of the documents to search
+   * @param neighborsFactory Creates SortedNumericDocValues representing the graph to search
+   * @param target the target vector
+   * @param maxDoc one more than the maximum document to search. This is used to generate seed entry points in the graph
+   * @return the queue of the (up to topK) nearest docs found. It is iterated in the queue's internal
+   * order, not necessarily sorted by distance. Scores are squared distances. When fewer than topK
+   * docs were found, the remaining entries are sentinels with doc == -1 and score == Float.MAX_VALUE.
+   * @throws IOException when there is an underlying exception reading the index
+   */
+  public Iterable<ScoreDoc> search(SupplierThrowsIoe<VectorDocValues> vectorsFactory,
+                                   SupplierThrowsIoe<SortedNumericDocValues> neighborsFactory,
+                                   float[] target, int maxDoc)
+      throws IOException {
+    assert maxDoc > 0;
+    queue = new ScoreDocQueue(topK, () -> new ScoreDoc(-1, Float.MAX_VALUE), false);
+    // probe round(2 * (ln(maxDoc) + 1)) entry points, so the count grows slowly with segment size
+    int numProbes = (int) Math.round(2 * (Math.log(maxDoc) + 1));
+    doSearch(vectorsFactory, neighborsFactory, target, maxDoc, numProbes);
+    return queue;
+  }
+
+  /**
+   * Runs one search into {@link #queue}. It probes up to {@code numProbes} entry documents spread
+   * across [0, maxDoc). From each competitive entry it repeatedly expands the nearest unexpanded
+   * candidate.
+   */
+  private void doSearch(SupplierThrowsIoe<VectorDocValues> vectorsFactory, SupplierThrowsIoe<SortedNumericDocValues> neighborsFactory,
+                        float[] target, int maxDoc, int numProbes) throws IOException {
+    scratch = new float[target.length]; // TODO: move to constructor and require dimension to be provided there
+    visited.clear();
+    // frontier entries left over from an earlier search hold distances to a different target
+    frontier.clear();
+    if (maxDoc <= 0) {
+      return; // empty segment: nothing to search
+    }
+    VectorDocValues vectors = vectorsFactory.get();
+    int vectorsDoc = -1; // the doc on which `vectors` is currently positioned
+    int entryIncrement = getEntryIncrement(numProbes, maxDoc);
+    int entryDocId = maxDoc % numProbes; // pseudorandom rotation among document probe cycles as the index increases in size
+    for (int i = 0; i < numProbes; i++, entryDocId += entryIncrement) {
+      if (VERBOSE) {
+        System.out.printf("[GraphSearch] entryDocId #%d = %d\n", i, entryDocId);
+      }
+      entryDocId %= maxDoc;
+      if (entryDocId <= vectorsDoc) {
+        // doc values iterators only move forward; use a fresh one to reach an earlier entry point
+        vectors = vectorsFactory.get();
+      }
+      int docId = vectors.advance(entryDocId);
+      if (docId == NO_MORE_DOCS || docId >= maxDoc) {
+        if (i > 0) {
+          return;
+        }
+        // edge case - we advanced past the end of the segment on our first attempt; try again from
+        // the beginning, with a fresh iterator since an exhausted one cannot be rewound
+        vectors = vectorsFactory.get();
+        docId = vectors.advance(0);
+        if (docId == NO_MORE_DOCS || docId >= maxDoc) {
+          return; // no document in this segment has a vector
+        }
+      }
+      vectorsDoc = docId;
+      if (visited.contains(docId)) {
+        continue;
+      }
+      ScoreDoc front = queue.top();
+      // if docid is competitive, front will be set to <docId, d(docId, target)>
+      enqueue(docId, target, vectors, front);
+      if (front.doc != docId) {
+        // FIXME - on following segments, score is not competitive here - we need to give it a chance
+        continue;
+      }
+      if (VERBOSE) {
+        System.out.printf("[GraphSearch] i=%d doc=%d\n", i, docId);
+      }
+      while (true) {
+        VectorDocValues childVectors = vectorsFactory.get();
+        SortedNumericDocValues neighbors = neighborsFactory.get();
+        // front is either the entry doc just enqueued or a doc taken from the frontier
+        ScoreDoc top = gather(childVectors, neighbors, target, front, maxDoc);
+        front = frontier.pollLast();
+        if (front == null || front.score > top.score) {
+          // No frontier doc is competitive
+          break;
+        }
+      }
+    }
+  }
+
+  /**
+   * Scores the not-yet-visited neighbors of {@code front}. Competitive neighbors are added to both
+   * the result queue and the frontier.
+   * @return the top of the result queue (its farthest retained doc) after the expansion
+   */
+  private ScoreDoc gather(VectorDocValues vectors, SortedNumericDocValues neighbors, float[] target, ScoreDoc front, int maxDoc) throws IOException {
+    assert front.doc >= 0;
+    assert front.doc < maxDoc : "docid=" + front.doc + ", maxDoc=" + maxDoc;
+    ScoreDoc bottom = queue.top();
+    if (neighbors.advanceExact(front.doc) == false) {
+      // when merging this seems to happen? why isn't it taken care of above?
+      return bottom;
+    }
+    int n = neighbors.docValueCount();
+    assert n > 0;
+    for (int i = 0; i < n; i++) {
+      int docId = (int) neighbors.nextValue();
+      assert docId >= 0;
+      assert docId < maxDoc : "docid=" + docId + ", maxDoc=" + maxDoc;
+      if (visited.contains(docId) == false) {
+        visited.add(docId);
+        boolean hasVector = vectors.advanceExact(docId);
+        assert hasVector : "doc " + (docId) + " has no vector";
+        vectors.vector(scratch);
+        float distance = distance(scratch, target, bottom.score);
+        if (VERBOSE) {
+          System.out.printf(" traverse doc=%d dist=%f\n", docId, distance);
+        }
+        // Add competitive neighbors to the output queue FIXME this test does not capture that we must compare scores here
+        if (updateQueue(bottom, docId, distance)) {
+          // and to the frontier for further expansion, creating a new ScoreDoc since
+          // we modify the docs in the result queue
+          frontier.add(new ScoreDoc(docId, distance));
+          bottom = queue.top();
+        }
+      }
+    }
+    return bottom;
+  }
+
+  /**
+   * Marks {@code doc} visited and scores it against the target. The doc is then offered to the
+   * result queue by way of {@code top}, which must be the queue's current top.
+   */
+  private void enqueue(int doc, float[] target, VectorDocValues vectors, ScoreDoc top) throws IOException {
+    boolean hasVector = vectors.advanceExact(doc);
+    assert hasVector : "doc " + doc + " has no vector";
+    visited.add(doc);
+    vectors.vector(scratch);
+    float score = distance(target, scratch, top.score);
+    updateQueue(top, doc, score);
+  }
+
+  /**
+   * If {@code doc} is nearer than {@code top} (the queue's farthest entry), overwrites top in place
+   * with doc and restores the queue order. Ties on distance favor the smaller docid.
+   * @return true if doc entered the queue
+   */
+  private boolean updateQueue(ScoreDoc top, int doc, float distance) {
+    if (distance < top.score || (distance == top.score && doc < top.doc)) {
+      // If this neighbor is competitive, add it to the topK queue
+      top.score = distance;
+      // segment-local docid; the segment's docBase is added in constructResults
+      top.doc = doc;
+      queue.updateTop();
+      return true;
+    } else {
+      return false;
+    }
+  }
+
+  /** Converts each segment's queue to TopDocs with global docids and merges them. */
+  private TopDocs constructResults(ScoreDocQueue[] queues, List<LeafReaderContext> contexts) {
+    TopDocs[] topDocs = new TopDocs[queues.length];
+    for (int i = 0; i < topDocs.length; i++) {
+      topDocs[i] = constructResults(queues[i], contexts.get(i).docBase);
+    }
+    return TopDocs.merge(topK, topDocs);
+  }
+
+  /**
+   * Drains one segment's queue into TopDocs, nearest first. Sentinels are dropped, docBase is added
+   * to every docid, and distances are negated into scores.
+   */
+  private TopDocs constructResults(ScoreDocQueue q, int docBase) {
+    int found = 0;
+    for (ScoreDoc scoreDoc : q) {
+      if (scoreDoc.doc >= 0) {
+        ++found;
+      }
+    }
+    ScoreDoc[] results = new ScoreDoc[found];
+    // pop() yields the farthest doc first, so fill the results from the end
+    for (int i = found - 1; i >= 0;) {
+      ScoreDoc scoreDoc = q.pop();
+      // skip sentinels
+      if (scoreDoc.doc != -1) {
+        scoreDoc.doc += docBase;
+        scoreDoc.score = -scoreDoc.score; // TopDocs.merge will sort in ascending score order
+        results[i--] = scoreDoc;
+      }
+    }
+    // the search is for the K nearest neighbors, so we never have more than K to return. The number
+    // found may be less than K though.
+    return new TopDocs(new TotalHits(found, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), results);
+  }
+
+  /**
+   * Squared Euclidean distance between a and b. Returns Float.MAX_VALUE as soon as the partial sum
+   * exceeds {@code minScore}, since such a doc cannot be competitive.
+   */
+  private static float distance(float[] a, float[] b, float minScore) {
+    assert a.length == b.length;
+    float total = 0;
+    for (int i = 0; i < a.length; i++) {
+      float d = a[i] - b[i];
+      total += d * d;
+      if (total > minScore) {
+        // return early since every dimension of the score is positive; it can only increase
+        // TODO: optimize by skipping this test until the queue is full of non-sentinels
+        return Float.MAX_VALUE;
+      }
+    }
+    return total;
+  }
+
+  /** Spacing between successive probe entry points: maxDoc is split into m + 1 intervals. */
+  private static int getEntryIncrement(int m, int maxDoc) {
+    return maxDoc / (m + 1);
+  }
+
+  /**
+   * Prefers docs with lower (positive) scores and lower docids
+   */
+  private static class ScoreDocQueue extends PriorityQueue<ScoreDoc> {
+    private final boolean ascending;
+
+    /**
+     * Creates a new queue with the given size and rank order
+     * @param capacity the number of elements the queue accommodates
+     * @param sentinel supplies the placeholder entries the queue is pre-filled with
+     * @param ascending if true, the least element is that with the least score. Conversely if false,
+     *                  the least element has the greatest score. In both cases, when scores are equal,
+     *                  a document with a higher docId is less than a document with a lower docId.
+     */
+    ScoreDocQueue(int capacity, Supplier<ScoreDoc> sentinel, boolean ascending) {
+      super(capacity, sentinel);
+      this.ascending = ascending;
+    }
+
+    @Override
+    protected boolean lessThan(ScoreDoc a, ScoreDoc b) {
+      if (a.score > b.score) {
+        return !ascending;
+      } else if (a.score < b.score) {
+        return ascending;
+      } else {
+        return a.doc > b.doc;
+      }
+    }
+  }
+
+  /**
+   * Orders ScoreDocs so that lower scores (nearer docs) sort last, with ties broken so that lower
+   * docids sort last.
+   */
+  private static int compareScoreDoc(ScoreDoc a, ScoreDoc b) {
+    if (a.score < b.score) {
+      return 1;
+    } else if (a.score > b.score) {
+      return -1;
+    } else {
+      // Integer.compare rather than subtraction, which can overflow
+      return Integer.compare(b.doc, a.doc);
+    }
+  }
+}
+
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestDocValues.java b/lucene/core/src/test/org/apache/lucene/index/TestDocValues.java
index 442fe7d..7473b19 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestDocValues.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestDocValues.java
@@ -23,6 +23,7 @@ import org.apache.lucene.document.BinaryDocValuesField;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.NumericDocValuesField;
+import org.apache.lucene.document.ReferenceDocValuesField;
import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.document.SortedSetDocValuesField;
@@ -256,6 +257,42 @@ public class TestDocValues extends LuceneTestCase {
dir.close();
}
+ /**
+ * field with reference docvalues
+ */
+ public void testReferenceField() throws Exception {
+ Directory dir = newDirectory();
+ IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null));
+ Document doc = new Document();
+ iw.addDocument(doc);
+ doc = new Document();
+ doc.add(new ReferenceDocValuesField("foo", 0));
+ iw.addDocument(doc);
+ DirectoryReader dr = DirectoryReader.open(iw);
+ LeafReader r = getOnlyLeafReader(dr);
+
+ // ok
+ assertNotNull(DocValues.getSortedNumeric(r, "foo"));
+
+ // errors
+ expectThrows(IllegalStateException.class, () -> {
+ DocValues.getBinary(r, "foo");
+ });
+ expectThrows(IllegalStateException.class, () -> {
+ DocValues.getNumeric(r, "foo");
+ });
+ expectThrows(IllegalStateException.class, () -> {
+ DocValues.getSorted(r, "foo");
+ });
+ expectThrows(IllegalStateException.class, () -> {
+ DocValues.getSortedSet(r, "foo");
+ });
+
+ dr.close();
+ iw.close();
+ dir.close();
+ }
+
public void testAddNullNumericDocValues() throws IOException {
Directory dir = newDirectory();
IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null));
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestDocValuesIndexing.java b/lucene/core/src/test/org/apache/lucene/index/TestDocValuesIndexing.java
index a838af3..3c2c31b 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestDocValuesIndexing.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestDocValuesIndexing.java
@@ -18,6 +18,11 @@ package org.apache.lucene.index;
import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -30,6 +35,7 @@ import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.NumericDocValuesField;
+import org.apache.lucene.document.ReferenceDocValuesField;
import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.document.SortedSetDocValuesField;
import org.apache.lucene.document.StringField;
@@ -875,4 +881,5 @@ public class TestDocValuesIndexing extends LuceneTestCase {
w.close();
dir.close();
}
+
}
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
new file mode 100644
index 0000000..b292c92
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.index;
+
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Set;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.KnnGraphField;
+import org.apache.lucene.document.StringField;
+import org.apache.lucene.search.GraphSearch;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.LuceneTestCase;
+
+/** Tests indexing of a knn-graph by KnnGraphWriter */
+public class TestKnnGraph extends LuceneTestCase {
+
+  private static final String KNN_GRAPH_FIELD = "vector";
+  /** Field holding each document's graph neighbors, companion to {@link #KNN_GRAPH_FIELD}. */
+  private static final String KNN_GRAPH_NBR_FIELD = "vector$nbr";
+
+  /**
+   * Basic test of creating documents in a graph; roughly half of the documents get a random vector
+   */
+  public void testBasic() throws Exception {
+    try (Directory dir = newDirectory();
+         IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null))) {
+      int numDoc = atLeast(10);
+      int dimension = atLeast(3);
+      float[][] values = new float[numDoc][];
+      for (int i = 0; i < numDoc; i++) {
+        if (random().nextBoolean()) {
+          values[i] = new float[dimension];
+          for (int j = 0; j < dimension; j++) {
+            values[i][j] = random().nextFloat();
+          }
+        }
+        add(iw, i, values[i]);
+      }
+      assertConsistentGraph(iw, dimension, values);
+    }
+  }
+
+  /**
+   * Verify that the graph properties are preserved when merging
+   */
+  public void testMerge() throws Exception {
+    try (Directory dir = newDirectory();
+         IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null))) {
+      int numDoc = atLeast(100);
+      int dimension = atLeast(10);
+      float[][] values = new float[numDoc][];
+      for (int i = 0; i < numDoc; i++) {
+        if (random().nextBoolean()) {
+          values[i] = new float[dimension];
+          for (int j = 0; j < dimension; j++) {
+            // FIXME why do all the distances look identical?
+            values[i][j] = random().nextFloat();
+          }
+        }
+        add(iw, i, values[i]);
+        if (random().nextInt(10) == 3) {
+          // randomly create multiple segments
+          iw.commit();
+        }
+      }
+      if (random().nextBoolean()) {
+        iw.forceMerge(1);
+      }
+      assertConsistentGraph(iw, dimension, values);
+    }
+  }
+
+  // TODO: testSorted
+  // TODO: testDeletions
+
+  /**
+   * Verify that searching does something reasonable
+   */
+  public void testSearch() throws Exception {
+    try (Directory dir = newDirectory();
+         IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null))) {
+      // Add a document for every cartesian point in an NxN square so we can
+      // easily know which are the nearest neighbors to every point. Insert by iterating
+      // using a step size that is coprime with N*N so that we will hit each point once,
+      // chosen so that points will be inserted in a deterministic
+      // but somewhat distributed pattern
+      int n = 5;
+      int stepSize = 17;
+      float[][] values = new float[n * n][];
+      int index = 0;
+      for (int i = 0; i < values.length; i++) {
+        values[i] = new float[]{index % n, index / n};
+        index = (index + stepSize) % (n * n);
+        add(iw, i, values[i]);
+        if (i == 13) {
+          // create 2 segments
+          iw.commit();
+        }
+      }
+      // randomly search either the two segments or a single merged one
+      if (random().nextBoolean()) {
+        iw.forceMerge(1);
+      }
+      assertConsistentGraph(iw, 2, values);
+      try (DirectoryReader dr = DirectoryReader.open(iw)) {
+        IndexSearcher searcher = new IndexSearcher(dr);
+        // results are ordered by increasing distance from the target, then by ascending docid;
+        // This is the docid ordering:
+        // column major, origin at upper left
+        //  0 15  5 20 10
+        //  3 18  8 23 13
+        //  6 21 11  1 16
+        //  9 24 14  4 19
+        // 12  2 17  7 22
+
+        // For this small graph it seems we can always get exact results with 2 probes
+        assertResults(new int[]{11, 1, 8, 14, 21},
+            GraphSearch.search(searcher, KNN_GRAPH_FIELD, 5, 2, new float[]{2, 2}));
+        assertResults(new int[]{0, 3, 15, 18, 5},
+            GraphSearch.search(searcher, KNN_GRAPH_FIELD, 5, 2, new float[]{0, 0}));
+        assertResults(new int[]{15, 18, 0, 3, 5},
+            GraphSearch.search(searcher, KNN_GRAPH_FIELD, 5, 2, new float[]{0.3f, 0.8f}));
+      }
+    }
+  }
+
+  /** Asserts that topDocs holds exactly the expected docids, in the expected order. */
+  private void assertResults(int[] expected, TopDocs topDocs) {
+    assertEquals(expected.length, topDocs.scoreDocs.length);
+    for (int i = 0; i < expected.length; i++) {
+      assertEquals("unexpected doc at rank " + i, expected[i], topDocs.scoreDocs[i].doc);
+    }
+  }
+
+  /**
+   * Checks every segment: docs have exactly the vectors they were indexed with, and the neighbor
+   * graph is reciprocal and connected.
+   */
+  private void assertConsistentGraph(IndexWriter iw, int dimension, float[][] values) throws IOException {
+    float[] scratch = new float[dimension];
+    try (DirectoryReader dr = DirectoryReader.open(iw)) {
+      for (LeafReaderContext ctx: dr.leaves()) {
+        LeafReader reader = ctx.reader();
+        VectorDocValues vectorDocValues = VectorDocValues.get(reader, KNN_GRAPH_FIELD);
+        SortedNumericDocValues neighbors = DocValues.getSortedNumeric(reader, KNN_GRAPH_NBR_FIELD);
+        int[][] graph = new int[reader.maxDoc()][];
+        boolean singleNodeGraph = false;
+        int graphSize = 0;
+        for (int i = 0; i < reader.maxDoc(); i++) {
+          int id = Integer.parseInt(reader.document(i).get("id"));
+          if (values[id] == null) {
+            // documents without KnnGraphValues have no vectors or neighbors
+            assertFalse("document " + id + " was not expected to have values", vectorDocValues.advanceExact(i));
+            assertFalse(neighbors.advanceExact(i));
+          } else {
+            ++graphSize;
+            // documents with KnnGraphValues have the expected vectors
+            assertTrue("doc " + i + " has no vector value", vectorDocValues.advanceExact(i));
+            vectorDocValues.vector(scratch);
+            assertArrayEquals(values[id], scratch, 0f);
+            // We collect neighbors for analysis below
+            if (neighbors.advanceExact(i)) {
+              graph[i] = new int[neighbors.docValueCount()];
+              for (int j = 0; j < graph[i].length; j++) {
+                graph[i][j] = (int) neighbors.nextValue();
+              }
+            } else {
+              // graph must have a single node
+              singleNodeGraph = true;
+            }
+          }
+        }
+        assertTrue(singleNodeGraph || graphSize != 1);
+        if (graphSize > 0) {
+          assertEquals(dimension, vectorDocValues.dimension());
+        }
+        // assert that the graph in each leaf is connected and undirected (ie links are reciprocated)
+        assertReciprocal(graph);
+        assertConnected(graph);
+      }
+    }
+  }
+
+  private void assertReciprocal(int[][] graph) {
+    // The graph is undirected: if a -> b then b -> a.
+    for (int i = 0; i < graph.length; i++) {
+      if (graph[i] != null) {
+        for (int k : graph[i]) {
+          // fail with a clear message rather than an NPE when the neighbor has no neighbors at all
+          assertNotNull("" + i + "->" + k + " is not reciprocated: " + k + " has no neighbors", graph[k]);
+          // neighbor lists come from SortedNumericDocValues, so they are sorted and binary-searchable
+          assertTrue("" + i + "->" + k + " is not reciprocated", Arrays.binarySearch(graph[k], i) >= 0);
+        }
+      }
+    }
+  }
+
+  private void assertConnected(int[][] graph) {
+    // every node in the graph is reachable from every other node
+    Set<Integer> visited = new HashSet<>();
+    List<Integer> queue = new LinkedList<>();
+    int count = 0;
+    for (int[] entry : graph) {
+      if (entry != null) {
+        if (queue.isEmpty()) {
+          queue.add(entry[0]); // start from any node
+          visited.add(entry[0]);
+        }
+        ++count;
+      }
+    }
+    // breadth-first traversal; nodes are marked when enqueued so each one is expanded exactly once
+    while (queue.isEmpty() == false) {
+      int i = queue.remove(0);
+      assertNotNull("expected neighbors of " + i, graph[i]);
+      for (int j : graph[i]) {
+        if (visited.add(j)) {
+          queue.add(j);
+        }
+      }
+    }
+    // every node that has neighbors was reached
+    assertEquals(count, visited.size());
+  }
+
+  /** Adds a doc with the given id and, when vector is non-null, a KnnGraphField holding it. */
+  private void add(IndexWriter iw, int id, float[] vector) throws IOException {
+    Document doc = new Document();
+    if (vector != null) {
+      doc.add(new KnnGraphField(KNN_GRAPH_FIELD, vector));
+    }
+    doc.add(new StringField("id", Integer.toString(id), Field.Store.YES));
+    iw.addDocument(doc);
+  }
+
+}
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestReferenceDocValues.java b/lucene/core/src/test/org/apache/lucene/index/TestReferenceDocValues.java
new file mode 100644
index 0000000..befa5d1
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/index/TestReferenceDocValues.java
@@ -0,0 +1,360 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.index;
+
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.NumericDocValuesField;
+import org.apache.lucene.document.ReferenceDocValuesField;
+import org.apache.lucene.document.StringField;
+import org.apache.lucene.search.Sort;
+import org.apache.lucene.search.SortField;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.LuceneTestCase;
+import org.junit.Ignore;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+/**
+ *
+ * Tests ReferenceDocValues integration into IndexWriter
+ *
+ */
+public class TestReferenceDocValues extends LuceneTestCase {
+
+  // TODO: specialize to graph (ie symmetric) / non-graph references
+  public void testSymmetricReference() throws Exception {
+    // two documents forming the graph 0-1: the second refers to the first, and the reference
+    // must be readable in both directions
+    try (Directory dir = newDirectory()) {
+      RandomIndexWriter writer = new RandomIndexWriter(random(), dir);
+      add(writer);
+      add(writer, 0);
+
+      try (DirectoryReader reader = writer.getReader()) {
+        writer.close();
+        LeafReader leaf = getOnlyLeafReader(reader);
+        SortedNumericDocValues refs = DocValues.getSortedNumeric(leaf, "field");
+        // doc 0 -> doc 1
+        assertEquals(0, refs.nextDoc());
+        assertEquals(1, refs.docValueCount());
+        assertEquals(1, refs.nextValue());
+        // doc 1 -> doc 0
+        assertEquals(1, refs.nextDoc());
+        assertEquals(1, refs.docValueCount());
+        assertEquals(0, refs.nextValue());
+      }
+    }
+  }
+
+  /** References that are not to an earlier document are rejected. */
+  public void testInvalidReferences() throws Exception {
+    try (Directory dir = newDirectory();
+         RandomIndexWriter writer = new RandomIndexWriter(random(), dir)) {
+      // self-reference
+      expectThrows(IllegalArgumentException.class, () -> add(writer, 0));
+      // forward reference
+      expectThrows(IllegalArgumentException.class, () -> add(writer, 1));
+      // negative reference
+      expectThrows(IllegalArgumentException.class, () -> add(writer, -1));
+    }
+  }
+
+  public void testSortedIndex() throws Exception {
+    // create the graph with two nodes: 0-1, in an index sorted on "sortkey"
+    IndexWriterConfig config = newIndexWriterConfig();
+    config.setIndexSort(new Sort(new SortField("sortkey", SortField.Type.INT)));
+    try (Directory dir = newDirectory();
+         RandomIndexWriter writer = new RandomIndexWriter(random(), dir, config)) {
+
+      // the docs are added in descending sortkey order, so the index sort swaps them
+      add(writer, "sortkey", 2);
+      add(writer, "sortkey", 1, 0);
+
+      try (DirectoryReader reader = writer.getReader()) {
+        LeafReader leaf = getOnlyLeafReader(reader);
+        SortedNumericDocValues refs = DocValues.getSortedNumeric(leaf, "field");
+        NumericDocValues sortkeys = DocValues.getNumeric(leaf, "sortkey");
+
+        // doc 0 has sortkey 1 and refers to doc 1
+        assertEquals(0, refs.nextDoc());
+        assertEquals(1, refs.docValueCount());
+        assertEquals(1, refs.nextValue());
+        assertEquals(0, sortkeys.nextDoc());
+        assertEquals(1, sortkeys.longValue());
+
+        // doc 1 has sortkey 2 and refers to doc 0
+        assertEquals(1, refs.nextDoc());
+        assertEquals(1, refs.docValueCount());
+        assertEquals(0, refs.nextValue());
+        assertEquals(1, sortkeys.nextDoc());
+        assertEquals(2, sortkeys.longValue());
+      }
+    }
+  }
+
+  @Ignore("only for asymmetric graph")
+  public void testFlushDeletes() throws Exception {
+    // create the graph with two nodes: 0-1
+    // delete one document and verify that its reference is deleted
+    try (Directory d = newDirectory()) {
+      IndexWriterConfig iwc = newIndexWriterConfig();
+      iwc.setSoftDeletesField(null);
+      // build the writer from iwc; otherwise the configuration above is silently ignored
+      RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc);
+
+      add(w, "id", "a");
+      add(w, "id", "b", 0);
+      w.deleteDocuments(new Term("id", "a"));
+
+      try (DirectoryReader r = w.getReader()) {
+        assertEquals(1, w.getDocStats().numDocs);
+        w.close();
+        SortedNumericDocValues values = DocValues.getSortedNumeric(getOnlyLeafReader(r), "field");
+        // The deleted docs' values are still there!
+        assertEquals(0, values.nextDoc());
+        assertEquals(1, values.nextValue());
+        // But references to it have been dropped
+        assertEquals(NO_MORE_DOCS, values.nextDoc());
+      }
+    }
+  }
+
+ @Ignore("only for asymmetric graph")
+ public void testMerge() throws Exception {
+ // create two segments and merge, creating a disconnected graph
+ try (Directory d = newDirectory();
+ RandomIndexWriter w = new RandomIndexWriter(random(), d)) {
+
+ add(w, "id", "a");
+ add(w, "id", "b", 0);
+ w.commit();
+
+ add(w, "id", "c");
+ add(w, "id", "d", 0);
+ w.forceMerge(1);
+
+ try (DirectoryReader r = w.getReader()) {
+ assertEquals(4, w.getDocStats().maxDoc);
+ assertEquals(4, w.getDocStats().numDocs);
+ SortedNumericDocValues values = DocValues.getSortedNumeric(getOnlyLeafReader(r), "field");
+ // extract one ref per doc
+ int docid;
+ Map<String, String> refs = new HashMap<>();
+ while ((docid = values.nextDoc()) != NO_MORE_DOCS) {
+ assertEquals(1, values.docValueCount());
+ String id = r.document(0).get("id");
+ refs.put(id, r.document(docid).get("id"));
+ }
+ assertEquals(4, refs.size());
+ assertEquals("a", refs.get("b"));
+ assertEquals("b", refs.get("a"));
+ assertEquals("c", refs.get("d"));
+ assertEquals("d", refs.get("c"));
+ }
+ }
+ }
+
+ @Ignore("only for asymmetric graph")
+ public void testMergeDeletes() throws Exception {
+ // create two segments, delete a document and merge. Verify that the reference to the deleted doc is dropped
+ try (Directory d = newDirectory();
+ RandomIndexWriter w = new RandomIndexWriter(random(), d)) {
+
+ add(w, "id", "a");
+ add(w, "id", "b", 0);
+ w.commit();
+
+ add(w, "id", "c");
+ add(w, "id", "d", 0);
+ w.deleteDocuments(new Term("id", "a"));
+ w.forceMerge(1);
+
+ try (DirectoryReader r = w.getReader()) {
+ // after force merge, the deleted document is really gone
+ assertEquals(3, w.getDocStats().maxDoc);
+ assertEquals(3, w.getDocStats().numDocs);
+ SortedNumericDocValues values = DocValues.getSortedNumeric(getOnlyLeafReader(r), "field");
+ int docid;
+ Map<String, String> refs = new HashMap<>();
+ while ((docid = values.nextDoc()) != NO_MORE_DOCS) {
+ assertEquals(1, values.docValueCount());
+ String id = r.document(0).get("id");
+ refs.put(id, r.document(docid).get("id"));
+ }
+ assertEquals(2, refs.size());
+ assertNull(refs.get("b"));
+ assertEquals("c", refs.get("d"));
+ assertEquals("d", refs.get("c"));
+ }
+ }
+ }
+
+  /**
+   * Indexes a random graph into a single segment and verifies the stored references.
+   *
+   * <p>Every even-numbered doc after doc 0 references between 1 and 16 (possibly repeated) earlier
+   * even-numbered docs. The test expects each doc to read back its own references plus a
+   * back-reference from every later doc that referenced it, sorted and de-duplicated. Half the time
+   * an index sort on "sortkey" is applied, and the expected doc ids are remapped accordingly.
+   */
+  // NOTE(review): "try" is presumably suppressed because resource d is never referenced in the body
+  // of the outer try statement -- confirm.
+  @SuppressWarnings("try")
+  public void testRandomGraph() throws Exception {
+    int numDocs = atLeast(1000) + 2;
+    // expected state, indexed by original (insertion-order) doc id until sorted below
+    List<Doc> expected = new ArrayList<>();
+    IndexWriterConfig iwc = new IndexWriterConfig();
+    // TODO: deletions; how about soft deletes?
+    boolean sortedIndex = random().nextBoolean();
+    if (sortedIndex) {
+      iwc.setIndexSort(new Sort(new SortField("sortkey", SortField.Type.INT)));
+    }
+    // we must write all documents that refer to each other in a single segment.
+    // This restriction means this ReferenceDocValues API cannot be public; it is only for the use of
+    // other classes that are aware of the segment state such as other DocValues writers.
+    iwc.setMaxBufferedDocs(numDocs + 1);
+    iwc.setRAMBufferSizeMB(IndexWriterConfig.DISABLE_AUTO_FLUSH);
+    try (Directory d = newDirectory();
+         IndexWriter w = new IndexWriter(d, iwc)) {
+      for (int i = 0; i < numDocs; i++) {
+        Document doc = new Document();
+        Doc expectedDoc = new Doc(i, random().nextInt(1000));
+        expected.add(expectedDoc);
+        // this doc's outgoing refs, de-duplicated
+        Set<Integer> refs = new HashSet<>();
+        if (i % 2 == 0 && i > 0) {
+          // Only even-numbered docs are in the graph, and every even-numbered doc is linked to at least one other
+          int numRefs = 1 + random().nextInt(16); // more than 8 sometimes so we exercise growing an internal array
+          for (int j = 0; j < numRefs; j++) {
+            // an even doc id in [0, i - 2]
+            int value = random().nextInt(i / 2) * 2;
+            refs.add(value);
+            // also expect the symmetric ref; Doc.add drops the repeat when the same value is drawn twice
+            expected.get(value).add(i);
+            doc.add(new ReferenceDocValuesField("field", value));
+          }
+          // replacing the list is safe: only later docs add back-references to doc i, and none exist yet
+          expectedDoc.refs = new ArrayList<>(refs);
+        }
+        doc.add(new NumericDocValuesField("sortkey", expectedDoc.sortKey));
+        // add the doc
+        w.addDocument(doc);
+      }
+      if (sortedIndex) {
+        // reorder the expectations by (sortKey, original id), then translate every ref from its
+        // original doc id to the doc's position in that order
+        Collections.sort(expected);
+        int[] docmap = new int[expected.size()];
+        int i = 0;
+        for (Doc doc : expected) {
+          docmap[doc.id] = i++;
+        }
+        for (Doc doc : expected) {
+          // map the expected refs to their new ids
+          doc.refs = doc.refs.stream().map(ref -> docmap[ref]).collect(Collectors.toList());
+        }
+      }
+      // values are compared as sorted lists below
+      for (Doc doc : expected) {
+        Collections.sort(doc.refs);
+      }
+      /*
+      for (int i = 0; i < numDocs && i < 20; i++){
+        System.out.println(expected.get(i));
+      }
+      */
+      try (DirectoryReader r = w.getReader()) {
+        // close the writer as soon as the reader has been obtained
+        w.close();
+        SortedNumericDocValues values = DocValues.getSortedNumeric(getOnlyLeafReader(r), "field");
+        NumericDocValues sortkey = DocValues.getNumeric(getOnlyLeafReader(r), "sortkey");
+        for (int i = 0; i < numDocs; i++) {
+          // every doc has a sortkey value, so this iterator visits each doc in turn
+          assertEquals(i, sortkey.nextDoc());
+          long sortValue = sortkey.longValue();
+          Doc expectedDoc = expected.get(i);
+          assertEquals(expectedDoc.sortKey, sortValue);
+          int originalId = expectedDoc.id;
+          List<Integer> actual = new ArrayList<>();
+          // only docs that were even-numbered at insertion time carry references
+          if (originalId % 2 == 0) {
+            assertEquals(i, values.nextDoc());
+            for (int j = 0; j < values.docValueCount(); j++) {
+              actual.add((int) values.nextValue());
+            }
+          }
+          assertEquals("values for doc " + i, expectedDoc.refs, actual);
+        }
+        assertEquals(NO_MORE_DOCS, values.nextDoc());
+      }
+    }
+  }
+
+  /** Adds a document holding only the given references. */
+  private void add(RandomIndexWriter iw, int... refs) throws IOException {
+    iw.addDocument(newDocWithRefs(refs));
+  }
+
+  /** Adds a document holding the given references plus a stored string field {@code field=value}. */
+  private void add(RandomIndexWriter iw, String field, String value, int... refs) throws IOException {
+    Document withId = newDocWithRefs(refs);
+    withId.add(new StringField(field, value, Field.Store.YES));
+    iw.addDocument(withId);
+  }
+
+  /** Adds a document holding the given references plus a numeric doc-values field {@code field=value}. */
+  private void add(RandomIndexWriter iw, String field, int value, int... refs) throws IOException {
+    Document withNumber = newDocWithRefs(refs);
+    withNumber.add(new NumericDocValuesField(field, value));
+    iw.addDocument(withNumber);
+  }
+
+  /** Builds a document with one {@link ReferenceDocValuesField} named "field" per ref, in argument order. */
+  private static Document newDocWithRefs(int... refs) {
+    Document result = new Document();
+    for (int target : refs) {
+      result.add(new ReferenceDocValuesField("field", target));
+    }
+    return result;
+  }
+
+  /**
+   * Mock document for recording expected state of index: the doc's original (insertion-order) id,
+   * its index sort key, and the ids of the docs it references.
+   */
+  static class Doc implements Comparable<Doc> {
+    int id;
+    int sortKey;
+    List<Integer> refs = new ArrayList<>();
+
+    Doc(int id, int sortKey) {
+      this.id = id;
+      this.sortKey = sortKey;
+    }
+
+    /**
+     * Appends {@code ref} unless it equals the most recently appended ref. Callers add values in
+     * nondecreasing order, so this keeps the list free of duplicates.
+     */
+    void add(int ref) {
+      if (refs.isEmpty() || refs.get(refs.size() - 1).intValue() != ref) {
+        refs.add(ref);
+      }
+    }
+
+    /** Orders by sort key, breaking ties by original id (the order an index sort produces). */
+    @Override
+    public int compareTo(Doc other) {
+      // Integer.compare instead of subtraction, which can overflow for extreme values
+      int cmp = Integer.compare(sortKey, other.sortKey);
+      return cmp != 0 ? cmp : Integer.compare(id, other.id);
+    }
+
+    @Override
+    public String toString() {
+      return "doc id=" + id + " sortkey=" + sortKey + " refs=" + refs;
+    }
+  }
+
+}
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestVectorDocValues.java b/lucene/core/src/test/org/apache/lucene/index/TestVectorDocValues.java
new file mode 100644
index 0000000..1cdab9e
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/index/TestVectorDocValues.java
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.index;
+
+
+import java.io.IOException;
+import java.nio.BufferUnderflowException;
+import java.nio.FloatBuffer;
+import java.util.Locale;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.VectorDocValuesField;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.FSDirectory;
+import org.apache.lucene.util.LuceneTestCase;
+
+import org.junit.Ignore;
+
+/**
+ * Tests {@link VectorDocValues}: indexing vectors with {@link VectorDocValuesField} and reading them
+ * back via iteration and advancing, plus (ignored) write/read performance measurements.
+ */
+public class TestVectorDocValues extends LuceneTestCase {
+
+  /**
+   * Basic test of creating indexing and retrieving instances of vector doc values.
+   */
+  public void testVectorField() throws Exception {
+    try (Directory dir = newDirectory();
+         IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null))) {
+      add(iw, 1, 2, 3);
+      try (DirectoryReader dr = DirectoryReader.open(iw)) {
+        LeafReader r = getOnlyLeafReader(dr);
+
+        // ok: vectors are exposed as binary doc values
+        assertNotNull(DocValues.getBinary(r, "foo"));
+        assertNotNull(VectorDocValues.get(r, "foo"));
+
+        // errors
+        expectThrows(IllegalStateException.class, () -> {
+          DocValues.getNumeric(r, "foo");
+        });
+        expectThrows(IllegalStateException.class, () -> {
+          DocValues.getSorted(r, "foo");
+        });
+        expectThrows(IllegalStateException.class, () -> {
+          DocValues.getSortedSet(r, "foo");
+        });
+      }
+    }
+  }
+
+  /** Checks how vectors whose length differs from the field's dimension are read back. */
+  public void testDimensions() throws Exception {
+    try (Directory dir = newDirectory();
+         IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null))) {
+      add(iw, 1, 2, 3);
+      iw.addDocument(new Document());
+      add(iw, -1, 0, 1);
+      add(iw, 0);
+      add(iw, 0, 0, 0, 0);
+      float[] vector = new float[3];
+      try (DirectoryReader dr = DirectoryReader.open(iw)) {
+        LeafReader r = getOnlyLeafReader(dr);
+        VectorDocValues values = VectorDocValues.get(r, "foo");
+        assertEquals(0, values.nextDoc());
+        assertEquals(3, values.dimension());
+        values.vector(vector);
+        assertArrayEquals(new float[]{1, 2, 3}, vector, 0);
+        // we skip doc 1, which had no values
+        assertEquals(2, values.nextDoc());
+        assertEquals(3, values.dimension());
+        values.vector(vector);
+        assertArrayEquals(new float[]{-1, 0, 1}, vector, 0);
+        // doc 3 holds a 1-element vector: reading 3 floats runs off the end
+        values.nextDoc();
+        expectThrows(BufferUnderflowException.class, () -> values.vector(vector));
+        // We ignore extra dimensions, but we should not
+        values.nextDoc();
+        values.vector(vector);
+        assertArrayEquals(new float[]{0, 0, 0}, vector, 0);
+      }
+    }
+  }
+
+  /** Checks that advanceExact() and advance() position the iterator on the right vector. */
+  public void testAdvance() throws Exception {
+    // We override advanceExact() and advance() to decode values and set the next value state: make sure that works
+    try (Directory dir = newDirectory();
+         IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null))) {
+      add(iw, 1, 2, 3);
+      iw.addDocument(new Document());
+      iw.addDocument(new Document());
+      add(iw, -1, 0, 1);
+      iw.addDocument(new Document());
+      float[] vector = new float[3];
+      try (DirectoryReader dr = DirectoryReader.open(iw)) {
+        LeafReader r = getOnlyLeafReader(dr);
+        VectorDocValues values = VectorDocValues.get(r, "foo");
+        // in the initial state we do no checking
+        // expectThrows(ArrayIndexOutOfBoundsException.class, () -> values.vector(vector));
+        // no values, but dimension is known
+        assertEquals(3, values.dimension());
+        // after successful advanceExact
+        assertTrue(values.advanceExact(0));
+        values.vector(vector);
+        assertArrayEquals(new float[]{1, 2, 3}, vector, 0);
+        // after unsuccessful advanceExact we do not check
+        // assertFalse(values.advanceExact(1));
+        // expectThrows(AssertionError.class, () -> values.vector(vector));
+        // after successful advance: docs 1 and 2 have no vector, so we land on doc 3
+        assertEquals(3, values.advance(2));
+        values.vector(vector);
+        assertArrayEquals(new float[]{-1, 0, 1}, vector, 0);
+        // after unsuccessful advance (no more docs)
+        // assertEquals(DocValuesIterator.NO_MORE_DOCS, values.advance(4));
+        // expectThrows(ArrayIndexOutOfBoundsException.class, () -> values.vector(vector));
+      }
+    }
+  }
+
+  public void testZeroLengthVector() throws Exception {
+    // We disallow 0-length vector values
+    expectThrows(IllegalArgumentException.class, () -> new VectorDocValuesField("foo", new float[0]));
+  }
+
+  /** Adds a document with the given vector in field "foo". */
+  private void add(IndexWriter iw, float ... values) throws IOException {
+    Document doc = new Document();
+    doc.add(new VectorDocValuesField("foo", values));
+    iw.addDocument(doc);
+  }
+
+  // TODO: test performance of writing/reading and compare with an implementation that uses memory-mapped I/O to directly map an array of floats
+  // TODO: implement vector-based matching using HNSW
+
+  /*
+
+  testPerf iterate values using nextValue()
+
+    [junit4] 1> 100000 docs, dim=10; write time 219ms, read time 51ms
+    [junit4] 1> 100000 docs, dim=10; write time 130ms, read time 56ms
+    [junit4] 1> 100000 docs, dim=10; write time 139ms, read time 40ms
+    [junit4] 1> 100000 docs, dim=100; write time 1016ms, read time 44ms
+    [junit4] 1> 100000 docs, dim=100; write time 982ms, read time 130ms
+    [junit4] 1> 100000 docs, dim=100; write time 994ms, read time 46ms
+    [junit4] 1> 100000 docs, dim=1000; write time 8449ms, read time 254ms
+    [junit4] 1> 100000 docs, dim=1000; write time 8796ms, read time 238ms
+    [junit4] 1> 100000 docs, dim=1000; write time 8923ms, read time 230ms
+
+  testPerf direct access to vector()
+
+    [junit4] 1> 100000 docs, dim=10; write time 142ms, read time 48ms
+    [junit4] 1> 100000 docs, dim=10; write time 158ms, read time 44ms
+    [junit4] 1> 100000 docs, dim=10; write time 133ms, read time 40ms
+    [junit4] 1> 100000 docs, dim=100; write time 909ms, read time 37ms
+    [junit4] 1> 100000 docs, dim=100; write time 893ms, read time 37ms
+    [junit4] 1> 100000 docs, dim=100; write time 898ms, read time 39ms
+    [junit4] 1> 100000 docs, dim=1000; write time 8696ms, read time 154ms
+    [junit4] 1> 100000 docs, dim=1000; write time 8689ms, read time 153ms
+    [junit4] 1> 100000 docs, dim=1000; write time 8641ms, read time 149ms
+
+  testPerf provide vector() using direct access to IndexInput and its ByteBuffer
+  to avoid copying bytes, still copying floats into array
+
+    [junit4] 1> 100000 docs, dim=10; write time 177ms, read time 45ms
+    [junit4] 1> 100000 docs, dim=10; write time 151ms, read time 36ms
+    [junit4] 1> 100000 docs, dim=10; write time 158ms, read time 29ms
+    [junit4] 1> 100000 docs, dim=100; write time 917ms, read time 41ms
+    [junit4] 1> 100000 docs, dim=100; write time 943ms, read time 39ms
+    [junit4] 1> 100000 docs, dim=100; write time 927ms, read time 39ms
+    [junit4] 1> 100000 docs, dim=1000; write time 9182ms, read time 131ms
+    [junit4] 1> 100000 docs, dim=1000; write time 8855ms, read time 141ms
+    [junit4] 1> 100000 docs, dim=1000; write time 8847ms, read time 128ms
+
+  testPerf provide vectorBuffer() using direct access to IndexInput and its ByteBuffer
+  repeated calls to VectorBuffer.get(int) ...
+    [junit4] 1> 100000 docs, dim=10; write time 168ms, read time 53ms
+    [junit4] 1> 100000 docs, dim=10; write time 177ms, read time 32ms
+    [junit4] 1> 100000 docs, dim=10; write time 158ms, read time 38ms
+    [junit4] 1> 100000 docs, dim=100; write time 947ms, read time 48ms
+    [junit4] 1> 100000 docs, dim=100; write time 918ms, read time 44ms
+    [junit4] 1> 100000 docs, dim=100; write time 915ms, read time 42ms
+    [junit4] 1> 100000 docs, dim=1000; write time 8894ms, read time 202ms
+    [junit4] 1> 100000 docs, dim=1000; write time 8792ms, read time 199ms
+    [junit4] 1> 100000 docs, dim=1000; write time 8941ms, read time 200ms
+
+  testPerf SlicedVectorValues (plumbing through IndexInput, iterating and converting from int bits to float)
+    [junit4] 1> 100000 docs, dim=10; write time 165ms, read time 44ms
+    [junit4] 1> 100000 docs, dim=10; write time 138ms, read time 42ms
+    [junit4] 1> 100000 docs, dim=10; write time 132ms, read time 37ms
+    [junit4] 1> 100000 docs, dim=100; write time 931ms, read time 59ms
+    [junit4] 1> 100000 docs, dim=100; write time 940ms, read time 61ms
+    [junit4] 1> 100000 docs, dim=100; write time 926ms, read time 60ms
+    [junit4] 1> 100000 docs, dim=1000; write time 8977ms, read time 624ms
+    [junit4] 1> 100000 docs, dim=1000; write time 9003ms, read time 623ms
+    [junit4] 1> 100000 docs, dim=1000; write time 8958ms, read time 624ms
+
+  testRawPerf (Just an in-memory float array, not Lucene; as a best-case)
+  NOTE(review): doRawPerf used to fill only every other buffer slot (a double increment), so the
+  write times below were likely measured over half the data -- re-measure.
+
+    [junit4] 1> 100000 docs, dim=10; write time 39ms, read time 0ms
+    [junit4] 1> 100000 docs, dim=10; write time 41ms, read time 0ms
+    [junit4] 1> 100000 docs, dim=10; write time 39ms, read time 0ms
+    [junit4] 1> 100000 docs, dim=100; write time 392ms, read time 7ms
+    [junit4] 1> 100000 docs, dim=100; write time 392ms, read time 7ms
+    [junit4] 1> 100000 docs, dim=100; write time 392ms, read time 6ms
+    [junit4] 1> 100000 docs, dim=1000; write time 3767ms, read time 70ms
+    [junit4] 1> 100000 docs, dim=1000; write time 3770ms, read time 70ms
+    [junit4] 1> 100000 docs, dim=1000; write time 3769ms, read time 70ms
+
+  */
+
+  @Ignore
+  public void testPerf() throws Exception {
+    // Write lots of vectors, then read them back
+    // int numDocs = atLeast(100_000);
+    int numDocs = 100_000;
+    // int iters = random().nextInt(5);
+    int iters = 4;
+    //perfTest(2, 5, 1);
+    perfTest(numDocs, 10, iters);
+    perfTest(numDocs, 100, iters);
+    perfTest(numDocs, 1000, iters);
+  }
+
+  /**
+   * Writes {@code numDocs} random vectors of the given dimension to a fresh on-disk index, reads
+   * them back summing per dimension, and checks the sums against the values written. The first
+   * iteration is treated as warm-up; timings are printed for the others.
+   */
+  private void perfTest(int numDocs, int dimension, int iters) throws IOException {
+    String field = "field";
+    for (int iter = 0; iter < iters; iter++) {
+      try (Directory dir = FSDirectory.open(createTempDir());
+           IndexWriter iw = new IndexWriter(dir, new IndexWriterConfig())) {
+        float[] vector = new float[dimension];
+        float[] sum = new float[dimension];
+        long tStart = System.nanoTime();
+        for (int i = 0; i < numDocs; i++) {
+          for (int j = 0; j < dimension; j++) {
+            vector[j] = random().nextFloat();
+            sum[j] += vector[j];
+          }
+          /*
+          for (int j = 0; j < dimension; j++) {
+            System.out.print(vector[j] + " ");
+          }
+          System.out.println(" @" + i);
+          */
+          Document doc = new Document();
+          doc.add(new VectorDocValuesField(field, vector));
+          iw.addDocument(doc);
+        }
+        long tWrite = System.nanoTime();
+        // reuse vector as the per-dimension accumulator for the read pass
+        for (int j = 0; j < dimension; j++) {
+          vector[j] = 0;
+        }
+        try (DirectoryReader dr = DirectoryReader.open(iw)) {
+          for (LeafReaderContext lrc : dr.leaves()) {
+            VectorDocValues vdv = VectorDocValues.get(lrc.reader(), field);
+            //VectorDocValues vdv = VectorDocValues.getSliced(lrc.reader(), field);
+            float[] values = new float[dimension];
+            while (vdv.nextDoc() != DocValuesIterator.NO_MORE_DOCS) {
+              vdv.vector(values);
+              for (int i = 0; i < dimension; i++) {
+                vector[i] += values[i];
+              }
+              /*
+              FloatBuffer buf = vdv.vectorBuffer();
+              int pos = buf.position();
+              for (int i = 0, j = pos; i < dimension; i++, j++) {
+                // System.out.print(buf.get(j) + " ");
+                vector[i] += buf.get(j);
+              }
+              */
+              /*
+              System.out.println("");
+              */
+            }
+          }
+        }
+        if (iter != 0) {
+          long tRead = System.nanoTime();
+          System.out.printf(Locale.ROOT,
+              "%d docs, dim=%d; write time %dms, read time %dms\n",
+              numDocs, dimension,
+              nsToMs(tWrite - tStart),
+              nsToMs(tRead - tWrite));
+        }
+        for (int i = 0; i < vector.length; i++) {
+          // NOTE(review): the zero-delta comparison assumes vectors are read back in the order they
+          // were written; if merges reorder segments, float rounding could make the sums differ
+          assertEquals("wrong value for dimension " + i, sum[i], vector[i], 0);
+        }
+      }
+    }
+  }
+
+  /*
+  public void testRawPerf() throws Exception {
+    int numDocs = 100000;
+    int iters = 4;
+    doRawPerf(numDocs, 10, iters);
+    doRawPerf(numDocs, 100, iters);
+    doRawPerf(numDocs, 1000, iters);
+  }
+  */
+
+  /**
+   * Best-case baseline for {@link #perfTest}: fills a flat in-memory float array with
+   * {@code numDocs * dimension} random values, then sums it per dimension. The first iteration is
+   * treated as warm-up; timings are printed for the others.
+   */
+  public void doRawPerf(int numDocs, int dimension, int iters) throws Exception {
+    float[] buffer = new float[dimension * numDocs];
+    for (int iter = 0; iter < iters; iter++) {
+      long tStart = System.nanoTime();
+      // fill every slot (previously i was incremented twice per iteration, skipping half of them)
+      for (int i = 0; i < buffer.length; i++) {
+        buffer[i] = random().nextFloat();
+      }
+      long tWrite = System.nanoTime();
+      float[] vector = new float[dimension];
+      for (int i = 0; i < numDocs; i++) {
+        for (int j = 0, k = i * dimension; j < dimension; j++, k++) {
+          vector[j] += buffer[k];
+        }
+      }
+      if (iter != 0) {
+        long tRead = System.nanoTime();
+        // explicit Locale, matching perfTest (locale-dependent printf is forbidden in Lucene)
+        System.out.printf(Locale.ROOT,
+            "%d docs, dim=%d; write time %dms, read time %dms\n",
+            numDocs, dimension,
+            nsToMs(tWrite - tStart),
+            nsToMs(tRead - tWrite));
+      }
+    }
+  }
+
+  /** Converts nanoseconds to whole milliseconds (truncating). */
+  private static long nsToMs(long ns) {
+    return ns / 1_000_000;
+  }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/search/KnnGraphTester.java b/lucene/core/src/test/org/apache/lucene/search/KnnGraphTester.java
new file mode 100644
index 0000000..67dade5
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/search/KnnGraphTester.java
@@ -0,0 +1,373 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.search;
+
+import java.io.BufferedInputStream;
+import java.io.BufferedOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Locale;
+import java.util.Random;
+import java.util.Set;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.KnnGraphField;
+import org.apache.lucene.document.StoredField;
+import org.apache.lucene.document.StringField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.DocValues;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.SortedNumericDocValues;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.FSDirectory;
+import org.apache.lucene.util.PriorityQueue;
+
+/**
+ * Command-line harness for knn-graph indexing and search.
+ *
+ * <ul>
+ *   <li>{@code -generate} writes a data file of random vectors plus their brute-force nearest
+ *       neighbors.
+ *   <li>{@code -search} indexes those vectors with {@link KnnGraphField}, runs random queries
+ *       through {@link GraphSearch}, and reports throughput and overlap with brute-force nearest
+ *       neighbors.
+ *   <li>{@code -stats} indexes the vectors and prints histograms of graph fanout.
+ * </ul>
+ */
+public class KnnGraphTester {
+
+  private final static String KNN_FIELD = "knn";
+  private final static String ID_FIELD = "id";
+
+  private final Random random;
+  private int numDocs;
+  private int dim;
+  private int topK;
+  private int numProbe;
+  /** All vectors, concatenated: vector i occupies {@code [i * dim, (i + 1) * dim)}. */
+  private float[] vectors;
+  /** Brute-force neighbors: those of vector i occupy {@code [i * topK, (i + 1) * topK)}. */
+  private int[] nabors;
+
+  KnnGraphTester() {
+    // set defaults; numDocs, dim and topK are replaced by readDataFile
+    numDocs = 10_000;
+    dim = 256;
+    topK = 10;
+    numProbe = 20;
+    random = new Random();
+  }
+
+  public static void main(String... args) throws Exception {
+    if (args.length != 2) {
+      usage();
+    }
+    switch (args[0]) {
+      case "-generate":
+        new KnnGraphTester().create(args[1]);
+        break;
+      case "-search":
+        new KnnGraphTester().search(args[1]);
+        break;
+      case "-stats":
+        new KnnGraphTester().stats(args[1]);
+        break;
+      default:
+        usage();
+    }
+  }
+
+  /** Loads the data file, (re)builds the index in ./knn_test_index, and benchmarks searches. */
+  private void search(String dataFile) throws IOException {
+    readDataFile(dataFile);
+    Path indexPath = Paths.get("knn_test_index");
+    createIndex(indexPath);
+    // topK = 25;
+    testSearch(indexPath, 1000);
+    //GraphSearch.VERBOSE = true;
+    //testSearch(indexPath, 1);
+  }
+
+  /** Loads the data file, (re)builds the index in ./knn_test_index, and prints fanout histograms. */
+  private void stats(String dataFile) throws IOException {
+    readDataFile(dataFile);
+    Path indexPath = Paths.get("knn_test_index");
+    createIndex(indexPath);
+    printFanoutHist(indexPath);
+  }
+
+  /** Prints a histogram of neighbor counts per doc for each segment and for the whole index. */
+  private void printFanoutHist(Path indexPath) throws IOException {
+    try (Directory dir = FSDirectory.open(indexPath);
+         DirectoryReader reader = DirectoryReader.open(dir)) {
+      int maxFanout = 0;
+      // one extra slot so printHist can always read hist[max], even for an empty index
+      int[] globalHist = new int[reader.maxDoc() + 1];
+      for (LeafReaderContext context : reader.leaves()) {
+        LeafReader leafReader = context.reader();
+        // each doc's neighbor list is read from the "<field>$nbr" sorted-numeric doc values
+        SortedNumericDocValues nbr = DocValues.getSortedNumeric(leafReader, KNN_FIELD + "$nbr");
+        int leafMaxFanout = 0;
+        int[] leafHist = new int[leafReader.maxDoc() + 1];
+        while (nbr.nextDoc() != SortedNumericDocValues.NO_MORE_DOCS) {
+          int n = nbr.docValueCount();
+          ++leafHist[n];
+          leafMaxFanout = Math.max(leafMaxFanout, n);
+          ++globalHist[n];
+          maxFanout = Math.max(maxFanout, n);
+        }
+        System.out.printf(Locale.ROOT, "Segment %d fanout\n", context.ord);
+        printHist(leafHist, leafMaxFanout);
+      }
+      System.out.println("Whole index fanout");
+      printHist(globalHist, maxFanout);
+    }
+  }
+
+  /** Prints {@code hist[0..max]} in rows of 25: a row of fanout values, then a row of counts. */
+  private void printHist(int[] hist, int max) {
+    System.out.printf(Locale.ROOT, "max fanout=%d, count[max]=%d\n", max, hist[max]);
+    for (int rowStart = 0; rowStart <= max; rowStart += 25) {
+      int rowEnd = Math.min(max, rowStart + 24);
+      for (int f = rowStart; f <= rowEnd; f++) {
+        System.out.printf(Locale.ROOT, "%4d", f);
+      }
+      System.out.println();
+      for (int f = rowStart; f <= rowEnd; f++) {
+        System.out.printf(Locale.ROOT, "%4d", hist[f]);
+      }
+      System.out.println();
+    }
+    System.out.println();
+  }
+
+  /**
+   * Runs {@code numIters} graph searches for random target vectors, prints throughput, and then
+   * compares the hits against brute-force nearest neighbors.
+   */
+  private void testSearch(Path indexPath, int numIters) throws IOException {
+    float[][] targets = new float[numIters][];
+    TopDocs[] results = new TopDocs[numIters];
+    for (int i = 0; i < numIters; i++) {
+      targets[i] = new float[dim];
+      randomVector(targets[i]);
+    }
+    System.out.println("running " + numIters + " targets; topK=" + topK + ", numProbe=" + numProbe);
+    long start = System.nanoTime();
+    try (Directory dir = FSDirectory.open(indexPath);
+         DirectoryReader reader = DirectoryReader.open(dir)) {
+      IndexSearcher searcher = new IndexSearcher(reader);
+      for (int i = 0; i < numIters; i++) {
+        results[i] = GraphSearch.search(searcher, KNN_FIELD, topK, numProbe, targets[i]);
+        // replace Lucene doc ids with the stored ids, which index into vectors[]
+        for (ScoreDoc scoreDoc : results[i].scoreDocs) {
+          scoreDoc.doc = searcher.doc(scoreDoc.doc).getField(ID_FIELD).numericValue().intValue();
+        }
+      }
+    }
+    long elapsed = (System.nanoTime() - start) / 1_000_000; // ns -> ms
+    // clamp the divisor so sub-millisecond runs do not throw ArithmeticException
+    long qps = 1000L * numIters / Math.max(1, elapsed);
+    System.out.println("completed " + numIters + " searches in " + elapsed + " ms: " + qps + " QPS");
+    System.out.println("checking results");
+    checkResults(targets, results);
+  }
+
+  /** Prints how many of each search's hits are among the target's true {@code topK} nearest neighbors. */
+  private void checkResults(float[][] targets, TopDocs[] results) {
+    int[] expected = new int[topK];
+    int totalMatches = 0;
+    for (int i = 0; i < results.length; i++) {
+      if (results[i].scoreDocs.length != topK) {
+        System.err.println("search " + i + " got " + results[i].scoreDocs.length + " results, expecting " + topK);
+      }
+      getActualNN(targets[i], 0, expected, 0);
+      int matched = compareNN(expected, results[i]);
+      totalMatches += matched;
+    }
+    System.out.println("total matches = " + totalMatches);
+    System.out.println("Average overlap = " + (100.0 * totalMatches / (results.length * topK)) + "%");
+  }
+
+  /** Returns how many hits in {@code results} have a doc id that appears in {@code expected}. */
+  int compareNN(int[] expected, TopDocs results) {
+    Set<Integer> expectedSet = new HashSet<>();
+    for (int doc : expected) {
+      expectedSet.add(doc);
+    }
+    int matched = 0;
+    for (ScoreDoc scoreDoc : results.scoreDocs) {
+      if (expectedSet.contains(scoreDoc.doc)) {
+        ++matched;
+      }
+    }
+    return matched;
+  }
+
+  /**
+   * Finds, by brute force, the {@code topK} vectors closest to the target (squared Euclidean
+   * distance), and writes their ids to {@code nn[nnOffset .. nnOffset + topK)}, nearest first.
+   *
+   * @param target array holding the target vector; may be {@code vectors} itself, in which case the
+   *     vector at {@code targetOffset} is not reported as its own neighbor
+   * @param targetOffset start of the target vector within {@code target}
+   * @param nn destination for the neighbor ids
+   * @param nnOffset first slot of {@code nn} to write
+   */
+  void getActualNN(float[] target, int targetOffset, int[] nn, int nnOffset) {
+    final ScoreDocQueue queue = new ScoreDocQueue(topK);
+    assert queue.size() == topK : " queue.size()=" + queue.size();
+    // the queue's top is the farthest of the best topK candidates seen so far
+    ScoreDoc bottom = queue.top();
+    for (int j = 0, vectorOffset = 0; j < numDocs; j++, vectorOffset += dim) {
+      if (target == vectors && targetOffset == vectorOffset) {
+        // skip the target itself; the loop header still advances j and vectorOffset
+        continue;
+      }
+      float d = distance(target, targetOffset, vectorOffset, bottom.score);
+      if (d < bottom.score) {
+        bottom.doc = j;
+        bottom.score = d;
+        bottom = queue.updateTop();
+      }
+    }
+    assert queue.size() == topK;
+    // pop() yields the farthest remaining candidate first, so fill the output from the back
+    for (int k = topK - 1; k >= 0; k--) {
+      ScoreDoc scoreDoc = queue.pop();
+      nn[nnOffset + k] = scoreDoc.doc;
+    }
+  }
+
+  /** Fills {@code vector} with uniform random values in [0, 1). */
+  private void randomVector(float[] vector) {
+    for (int i = 0; i < vector.length; i++) {
+      vector[i] = random.nextFloat();
+    }
+  }
+
+  /** Creates (overwriting) an index with one doc per vector: the vector plus its stored numeric id. */
+  private void createIndex(Path indexPath) throws IOException {
+    IndexWriterConfig iwc = new IndexWriterConfig()
+        .setOpenMode(IndexWriterConfig.OpenMode.CREATE);
+    System.out.println("creating index in " + indexPath);
+    long start = System.nanoTime();
+    try (FSDirectory dir = FSDirectory.open(indexPath);
+         IndexWriter iw = new IndexWriter(dir, iwc)) {
+      for (int i = 0; i < numDocs; i++) {
+        float[] vector = Arrays.copyOfRange(vectors, i * dim, (i + 1) * dim);
+        Document doc = new Document();
+        doc.add(new KnnGraphField(KNN_FIELD, vector));
+        doc.add(new StoredField(ID_FIELD, i));
+        iw.addDocument(doc);
+      }
+    }
+    long elapsed = System.nanoTime() - start;
+    System.out.println("Indexed " + numDocs + " documents in " + elapsed / 1_000_000 + "ms");
+  }
+
+  /** Generates random vectors, computes their nearest neighbors, and writes both to {@code dataFile}. */
+  private void create(String dataFile) throws IOException {
+    generateRandomVectors(dim * numDocs);
+    System.out.println("Generated " + numDocs + " random vectors");
+    computeNearest();
+    writeDataFile(dataFile);
+  }
+
+  /**
+   * Reads a data file in the format produced by {@link #writeDataFile}: ints numDocs, dim and topK,
+   * then {@code numDocs * dim} floats, then {@code numDocs * topK} neighbor ids.
+   */
+  private void readDataFile(String dataFile) throws IOException {
+    try (InputStream in = Files.newInputStream(Paths.get(dataFile));
+         BufferedInputStream bin = new BufferedInputStream(in);
+         DataInputStream din = new DataInputStream(bin)) {
+      numDocs = din.readInt();
+      dim = din.readInt();
+      topK = din.readInt();
+      vectors = new float[numDocs * dim];
+      for (int i = 0; i < vectors.length; i++) {
+        vectors[i] = din.readFloat();
+      }
+      nabors = new int[numDocs * topK];
+      for (int i = 0; i < nabors.length; i++) {
+        nabors[i] = din.readInt();
+      }
+    }
+  }
+
+  /** Writes numDocs, dim, topK, all vectors, then all neighbor ids, using DataOutputStream encoding. */
+  private void writeDataFile(String dataFile) throws IOException {
+    try (OutputStream out = Files.newOutputStream(Paths.get(dataFile));
+         BufferedOutputStream bout = new BufferedOutputStream(out);
+         DataOutputStream dout = new DataOutputStream(bout)) {
+      dout.writeInt(numDocs);
+      dout.writeInt(dim);
+      dout.writeInt(topK);
+      for (int i = 0; i < vectors.length; i++) {
+        dout.writeFloat(vectors[i]);
+      }
+      for (int i = 0; i < nabors.length; i++) {
+        dout.writeInt(nabors[i]);
+      }
+    }
+  }
+
+  /** Allocates {@code vectors} with {@code size} floats and fills it with random values. */
+  private void generateRandomVectors(int size) {
+    // long arithmetic so the size report cannot overflow for very large inputs
+    System.out.println("Allocating " + size * 4L / 1024 / 1024 + "MB");
+    vectors = new float[size];
+    randomVector(vectors);
+  }
+
+  /** Fills {@code nabors} with each vector's brute-force nearest neighbors, excluding itself. */
+  private void computeNearest() {
+    nabors = new int[topK * numDocs];
+    System.out.println("finding nearest...");
+    for (int i = 0; i < numDocs; i++) {
+      if (i % 100 == 1) {
+        System.out.println(" " + (i - 1));
+      }
+      getActualNN(vectors, i * dim, nabors, i * topK);
+    }
+  }
+
+  /**
+   * Returns the squared Euclidean distance between the target vector and the vector at
+   * {@code vectorOffset} in {@code vectors}. Returns {@link Float#MAX_VALUE} as soon as the partial
+   * sum exceeds {@code scoreToBeat}.
+   */
+  private float distance(float[] target, int targetOffset, int vectorOffset, float scoreToBeat) {
+    float total = 0;
+    for (int i = 0; i < dim; i++) {
+      float d = target[targetOffset++] - vectors[vectorOffset++];
+      total += d * d;
+      if (total > scoreToBeat) {
+        // return early since every dimension of the score is positive; it can only increase
+        return Float.MAX_VALUE;
+      }
+    }
+    return total;
+  }
+
+  /** Prints usage to stderr and exits the JVM with status 1. */
+  private static void usage() {
+    String error = "Usage: KnnGraphTester -generate|-search|-stats {datafile}";
+    System.err.println(error);
+    System.exit(1);
+  }
+
+  /**
+   * Queue whose top is the entry with the largest score (the farthest retained doc); among equal
+   * scores, the larger doc id is on top. It is pre-filled with sentinels scored
+   * {@link Float#MAX_VALUE}, which are replaced as closer docs are found.
+   */
+  private static class ScoreDocQueue extends PriorityQueue<ScoreDoc> {
+    ScoreDocQueue(int size) {
+      super(size, () -> new ScoreDoc(-1, Float.MAX_VALUE));
+    }
+
+    @Override
+    protected boolean lessThan(ScoreDoc a, ScoreDoc b) {
+      if (a.score > b.score) {
+        return true;
+      } else if (a.score < b.score) {
+        return false;
+      } else {
+        return a.doc > b.doc;
+      }
+    }
+  }
+
+}
diff --git a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java
index ca9ccd7..a3b801f 100644
--- a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java
+++ b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java
@@ -1220,7 +1220,7 @@ public class MemoryIndex {
return null;
}
}
-
+
@Override
public SortedSetDocValues getSortedSetDocValues(String field) {
Info info = getInfoForExpectedDocValuesType(field, DocValuesType.SORTED_SET);
diff --git a/solr/core/src/test/org/apache/solr/search/TestDocSet.java b/solr/core/src/test/org/apache/solr/search/TestDocSet.java
index 20879ea..7ac47b7 100644
--- a/solr/core/src/test/org/apache/solr/search/TestDocSet.java
+++ b/solr/core/src/test/org/apache/solr/search/TestDocSet.java
@@ -427,7 +427,7 @@ public class TestDocSet extends SolrTestCase {
public SortedNumericDocValues getSortedNumericDocValues(String field) {
return null;
}
-
+
@Override
public SortedSetDocValues getSortedSetDocValues(String field) {
return null;