You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by za...@apache.org on 2022/05/12 05:26:40 UTC

[lucene] branch main updated: LUCENE-10411: Add NN vectors support to ExitableDirectoryReader (#833)

This is an automated email from the ASF dual-hosted git repository.

zacharymorn pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/main by this push:
     new 96036bca9f6 LUCENE-10411: Add NN vectors support to ExitableDirectoryReader (#833)
96036bca9f6 is described below

commit 96036bca9f667edbdc528bfe95eeb2795526e9fa
Author: zacharymorn <za...@yahoo.com>
AuthorDate: Wed May 11 22:26:35 2022 -0700

    LUCENE-10411: Add NN vectors support to ExitableDirectoryReader (#833)
---
 lucene/CHANGES.txt                                 |   3 +
 .../lucene/index/ExitableDirectoryReader.java      | 118 ++++++++++++++++++++-
 .../apache/lucene/index/FilterVectorValues.java    |  75 +++++++++++++
 .../lucene/index/TestExitableDirectoryReader.java  |  97 +++++++++++++++++
 4 files changed, 292 insertions(+), 1 deletion(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index bebc267e39b..4556b546cee 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -141,6 +141,9 @@ Optimizations
 * LUCENE-8836: Speed up calls to TermsEnum#lookupOrd on doc values terms enums
   and sequences of increasing ords. (Bruno Roustant, Adrien Grand)
 
+* LUCENE-10411: Add nearest-neighbor vector support to ExitableDirectoryReader.
+  (Zach Chen, Adrien Grand, Julie Tibshirani, Tomoko Uchida)
+
 * LUCENE-10542: FieldSource exists implementations can avoid value retrieval (Kevin Risden)
 
 * LUCENE-10534: MinFloatFunction / MaxFloatFunction exists check can be slow (Kevin Risden)
diff --git a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java
index dfdd4e6a66f..9f5ae09b141 100644
--- a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java
@@ -20,6 +20,8 @@ import java.io.IOException;
 import org.apache.lucene.index.FilterLeafReader.FilterTerms;
 import org.apache.lucene.index.FilterLeafReader.FilterTermsEnum;
 import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.automaton.CompiledAutomaton;
 
@@ -323,6 +325,62 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
           : sortedSetDocValues;
     }
 
+    @Override
+    public VectorValues getVectorValues(String field) throws IOException {
+      final VectorValues delegate = in.getVectorValues(field);
+      // A missing field (null) is passed straight through, and without an enabled timeout there is
+      // nothing to check, so the wrapper is only installed when it can actually do work.
+      if (delegate == null || queryTimeout.isTimeoutEnabled() == false) {
+        return delegate;
+      }
+      return new ExitableVectorValues(delegate);
+    }
+
+    /**
+     * Runs the nearest-neighbor search on the wrapped reader. When the query timeout is enabled,
+     * the accepted-docs filter is wrapped so that its first lookup, and every 10th lookup after
+     * that, checks the timeout and the thread's interrupt status and throws {@link
+     * ExitingReaderException} if either has fired.
+     */
+    @Override
+    public TopDocs searchNearestVectors(
+        String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
+      // Same policy as getVectorValues: with no timeout configured, skip the wrapping (and its
+      // per-lookup overhead and allocations) and search the wrapped reader directly.
+      if (queryTimeout.isTimeoutEnabled() == false) {
+        return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
+      }
+
+      // A null acceptDocs (no deleted docs) gives us nothing to hook into, so substitute a filter
+      // that matches all docs purely to get a place to run the timeout checks.
+      final Bits updatedAcceptDocs =
+          acceptDocs == null ? new Bits.MatchAllBits(maxDoc()) : acceptDocs;
+
+      final Bits timeoutCheckingAcceptDocs =
+          new Bits() {
+            private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 10;
+            private int calls;
+
+            @Override
+            public boolean get(int index) {
+              // Sample the timeout instead of checking it on every lookup.
+              if (calls++ % MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK == 0) {
+                checkAndThrowForSearchVectors();
+              }
+              return updatedAcceptDocs.get(index);
+            }
+
+            @Override
+            public int length() {
+              return updatedAcceptDocs.length();
+            }
+          };
+
+      return in.searchNearestVectors(field, target, k, timeoutCheckingAcceptDocs, visitedLimit);
+    }
+
+    private void checkAndThrowForSearchVectors() {
+      if (queryTimeout.shouldExit()) {
+        throw new ExitingReaderException(
+            "The request took too long to search nearest vectors. Timeout: "
+                + queryTimeout.toString()
+                + ", Reader="
+                + in);
+      } else if (Thread.interrupted()) {
+        throw new ExitingReaderException(
+            "Interrupted while searching nearest vectors. Reader=" + in);
+      }
+    }
+
     /**
      * Throws {@link ExitingReaderException} if {@link QueryTimeout#shouldExit()} returns true, or
      * if {@link Thread#interrupted()} returns true.
@@ -338,7 +396,65 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
                 + in);
       } else if (Thread.interrupted()) {
         throw new ExitingReaderException(
-            "Interrupted while iterating over point values. PointValues=" + in);
+            "Interrupted while iterating over doc values. DocValues=" + in);
+      }
+    }
+
+    /**
+     * {@link VectorValues} wrapper that checks the query timeout and the thread's interrupt status:
+     * while moving through doc ids, checks are at least {@code DOCS_BETWEEN_TIMEOUT_CHECK} doc ids
+     * apart (the first move always checks); every vector retrieval checks unconditionally.
+     */
+    private class ExitableVectorValues extends FilterVectorValues {
+      /** Smallest doc id at which moving the iterator triggers the next timeout check. */
+      private int nextCheckDoc;
+
+      public ExitableVectorValues(VectorValues vectorValues) {
+        super(vectorValues);
+      }
+
+      @Override
+      public int advance(int target) throws IOException {
+        return sampleTimeout(super.advance(target));
+      }
+
+      @Override
+      public int nextDoc() throws IOException {
+        return sampleTimeout(super.nextDoc());
+      }
+
+      @Override
+      public float[] vectorValue() throws IOException {
+        checkAndThrow();
+        return in.vectorValue();
+      }
+
+      @Override
+      public BytesRef binaryValue() throws IOException {
+        checkAndThrow();
+        return in.binaryValue();
+      }
+
+      /**
+       * Runs a timeout check if {@code doc} has reached the next checkpoint and schedules the
+       * following one; returns {@code doc} unchanged so callers can return it directly.
+       */
+      private int sampleTimeout(int doc) {
+        if (doc >= nextCheckDoc) {
+          checkAndThrow();
+          nextCheckDoc = doc + DOCS_BETWEEN_TIMEOUT_CHECK;
+        }
+        return doc;
+      }
+
+      /**
+       * Throws {@link ExitingReaderException} if {@link QueryTimeout#shouldExit()} returns true, or
+       * if {@link Thread#interrupted()} returns true.
+       */
+      private void checkAndThrow() {
+        final String message;
+        if (queryTimeout.shouldExit()) {
+          message =
+              "The request took too long to iterate over vector values. Timeout: "
+                  + queryTimeout.toString()
+                  + ", VectorValues="
+                  + in;
+        } else if (Thread.interrupted()) {
+          message = "Interrupted while iterating over vector values. VectorValues=" + in;
+        } else {
+          return;
+        }
+        throw new ExitingReaderException(message);
       }
     }
   }
diff --git a/lucene/core/src/java/org/apache/lucene/index/FilterVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/FilterVectorValues.java
new file mode 100644
index 00000000000..32dfc37a0f1
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/index/FilterVectorValues.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.index;
+
+import java.io.IOException;
+import java.util.Objects;
+import org.apache.lucene.util.BytesRef;
+
+/**
+ * A {@link VectorValues} that forwards every call to another {@link VectorValues} instance.
+ * Subclasses can override individual methods to change their behavior.
+ */
+public abstract class FilterVectorValues extends VectorValues {
+
+  /** The wrapped values that every call is forwarded to; never {@code null}. */
+  protected final VectorValues in;
+
+  /**
+   * Sole constructor.
+   *
+   * @param in the values to wrap
+   * @throws NullPointerException if {@code in} is {@code null}
+   */
+  protected FilterVectorValues(VectorValues in) {
+    this.in = Objects.requireNonNull(in);
+  }
+
+  // VectorValues methods
+
+  @Override
+  public int dimension() {
+    return in.dimension();
+  }
+
+  @Override
+  public int size() {
+    return in.size();
+  }
+
+  @Override
+  public float[] vectorValue() throws IOException {
+    return in.vectorValue();
+  }
+
+  @Override
+  public BytesRef binaryValue() throws IOException {
+    return in.binaryValue();
+  }
+
+  // DocIdSetIterator methods
+
+  @Override
+  public int docID() {
+    return in.docID();
+  }
+
+  @Override
+  public int nextDoc() throws IOException {
+    return in.nextDoc();
+  }
+
+  @Override
+  public int advance(int target) throws IOException {
+    return in.advance(target);
+  }
+
+  @Override
+  public long cost() {
+    return in.cost();
+  }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java
index 84c466e4aed..37d7c55d7d8 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java
@@ -16,6 +16,8 @@
  */
 package org.apache.lucene.index;
 
+import static com.carrotsearch.randomizedtesting.RandomizedTest.atMost;
+
 import java.io.IOException;
 import java.util.Arrays;
 import org.apache.lucene.document.*;
@@ -428,6 +430,101 @@ public class TestExitableDirectoryReader extends LuceneTestCase {
     directory.close();
   }
 
+  public void testVectorValues() throws IOException {
+    Directory directory = newDirectory();
+    IndexWriter writer =
+        new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
+
+    int numDoc = atLeast(20);
+    int deletedDoc = atMost(5);
+    int dimension = atLeast(3);
+
+    // The field type is the same for every document, so build it once.
+    FieldType fieldType =
+        KnnVectorField.createFieldType(dimension, VectorSimilarityFunction.COSINE);
+    for (int i = 0; i < numDoc; i++) {
+      Document doc = new Document();
+      doc.add(new KnnVectorField("vector", randomNonZeroVector(dimension), fieldType));
+      doc.add(new StringField("id", Integer.toString(i), Field.Store.YES));
+      writer.addDocument(doc);
+    }
+
+    writer.forceMerge(1);
+    writer.commit();
+
+    for (int i = 0; i < deletedDoc; i++) {
+      writer.deleteDocuments(new Term("id", Integer.toString(i)));
+    }
+
+    writer.close();
+
+    QueryTimeout queryTimeout;
+    if (random().nextBoolean()) {
+      if (random().nextBoolean()) {
+        queryTimeout = immediateQueryTimeout();
+      } else {
+        queryTimeout = infiniteQueryTimeout();
+      }
+    } else {
+      queryTimeout = disabledQueryTimeout();
+    }
+
+    DirectoryReader directoryReader = DirectoryReader.open(directory);
+    DirectoryReader exitableDirectoryReader =
+        new ExitableDirectoryReader(directoryReader, queryTimeout);
+    IndexReader reader = new TestReader(getOnlyLeafReader(exitableDirectoryReader));
+
+    LeafReaderContext context = reader.leaves().get(0);
+    LeafReader leaf = context.reader();
+
+    // Cosine similarity is undefined (0/0) for an all-zero vector, so query with a random
+    // non-zero vector rather than new float[dimension].
+    float[] queryVector = randomNonZeroVector(dimension);
+
+    if (queryTimeout.shouldExit()) {
+      expectThrows(
+          ExitingReaderException.class,
+          () -> {
+            DocIdSetIterator iter = leaf.getVectorValues("vector");
+            scanAndRetrieve(leaf, iter);
+          });
+
+      expectThrows(
+          ExitingReaderException.class,
+          () ->
+              leaf.searchNearestVectors(
+                  "vector", queryVector, 5, leaf.getLiveDocs(), Integer.MAX_VALUE));
+    } else {
+      DocIdSetIterator iter = leaf.getVectorValues("vector");
+      scanAndRetrieve(leaf, iter);
+
+      leaf.searchNearestVectors("vector", queryVector, 5, leaf.getLiveDocs(), Integer.MAX_VALUE);
+    }
+
+    reader.close();
+    directory.close();
+  }
+
+  /**
+   * Returns a random vector whose components all lie in (0, 1], so it is never the zero vector
+   * (for which cosine similarity is undefined).
+   */
+  private static float[] randomNonZeroVector(int dimension) {
+    float[] vector = new float[dimension];
+    for (int i = 0; i < dimension; i++) {
+      // nextFloat() is in [0, 1); subtracting from 1 keeps every component strictly positive.
+      vector[i] = 1f - random().nextFloat();
+    }
+    return vector;
+  }
+
+  /**
+   * Walks {@code iter} until it is exhausted, randomly choosing between {@code advance(docID() +
+   * 1)} and {@code nextDoc()} at each step, and randomly fetching the current vector when {@code
+   * iter} is a {@link VectorValues}.
+   */
+  private static void scanAndRetrieve(LeafReader leaf, DocIdSetIterator iter) throws IOException {
+    final int maxDoc = leaf.maxDoc();
+    iter.nextDoc();
+    while (iter.docID() != DocIdSetIterator.NO_MORE_DOCS && iter.docID() < maxDoc) {
+      final int target = iter.docID() + 1;
+      final boolean tryAdvance = random().nextBoolean();
+      if (tryAdvance && target < maxDoc) {
+        iter.advance(target);
+      } else {
+        iter.nextDoc();
+      }
+
+      final boolean fetchVector = random().nextBoolean();
+      final boolean positioned = iter.docID() != DocIdSetIterator.NO_MORE_DOCS;
+      if (fetchVector && positioned && iter instanceof VectorValues) {
+        ((VectorValues) iter).vectorValue();
+      }
+    }
+  }
+
   private static void scan(LeafReader leaf, DocValuesIterator iter) throws IOException {
     for (iter.nextDoc();
         iter.docID() != DocIdSetIterator.NO_MORE_DOCS && iter.docID() < leaf.maxDoc(); ) {