You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by be...@apache.org on 2023/05/12 14:59:40 UTC

[lucene] branch branch_9x updated: Backport: Concurrent rewrite for KnnVectorQuery (#12160) (#12288)

This is an automated email from the ASF dual-hosted git repository.

benwtrent pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/branch_9x by this push:
     new 972de546e9b Backport: Concurrent rewrite for KnnVectorQuery (#12160) (#12288)
972de546e9b is described below

commit 972de546e9be2300720deaa456c8ac38e7e9f423
Author: Benjamin Trent <be...@gmail.com>
AuthorDate: Fri May 12 10:59:34 2023 -0400

    Backport: Concurrent rewrite for KnnVectorQuery (#12160) (#12288)
    
    * Concurrent rewrite for KnnVectorQuery (#12160)
    
    
    - Reduce overhead of non-concurrent search by preserving the original sequential execution path when no executor is configured
    - Improve readability by factoring into separate functions
    
    ---------
    
    Co-authored-by: Kaival Parikh <ka...@gmail.com>
    
    * adjusting for backport
    
    ---------
    
    Co-authored-by: Kaival Parikh <46...@users.noreply.github.com>
    Co-authored-by: Kaival Parikh <ka...@gmail.com>
---
 lucene/CHANGES.txt                                 |  2 +
 .../lucene/search/AbstractKnnVectorQuery.java      | 73 ++++++++++++++++++----
 .../lucene/search/BaseKnnVectorQueryTestCase.java  |  9 ++-
 3 files changed, 72 insertions(+), 12 deletions(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 2f3a572a1bf..8875a3f0953 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -25,6 +25,8 @@ Optimizations
 
 * GITHUB#12270 Don't generate stacktrace in CollectionTerminatedException. (Armin Braun)
 
+* GITHUB#12160: Concurrent rewrite for AbstractKnnVectorQuery. (Kaival Parikh)
+
 Bug Fixes
 ---------------------
 (No changes)
diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
index 9403a6413e2..d6b9e04b542 100644
--- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
@@ -21,7 +21,12 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 import java.io.IOException;
 import java.util.Arrays;
 import java.util.Comparator;
+import java.util.List;
 import java.util.Objects;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executor;
+import java.util.concurrent.FutureTask;
+import java.util.stream.Collectors;
 import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.IndexReader;
@@ -29,6 +34,7 @@ import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.util.BitSet;
 import org.apache.lucene.util.BitSetIterator;
 import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.ThreadInterruptedException;
 
 /**
  * Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
@@ -62,9 +68,8 @@ abstract class AbstractKnnVectorQuery extends Query {
   @Override
   public Query rewrite(IndexSearcher indexSearcher) throws IOException {
     IndexReader reader = indexSearcher.getIndexReader();
-    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
 
-    Weight filterWeight = null;
+    final Weight filterWeight;
     if (filter != null) {
       BooleanQuery booleanQuery =
           new BooleanQuery.Builder()
@@ -73,17 +78,16 @@ abstract class AbstractKnnVectorQuery extends Query {
               .build();
       Query rewritten = indexSearcher.rewrite(booleanQuery);
       filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
+    } else {
+      filterWeight = null;
     }
 
-    for (LeafReaderContext ctx : reader.leaves()) {
-      TopDocs results = searchLeaf(ctx, filterWeight);
-      if (ctx.docBase > 0) {
-        for (ScoreDoc scoreDoc : results.scoreDocs) {
-          scoreDoc.doc += ctx.docBase;
-        }
-      }
-      perLeafResults[ctx.ord] = results;
-    }
+    Executor executor = indexSearcher.getExecutor();
+    TopDocs[] perLeafResults =
+        (executor == null)
+            ? sequentialSearch(reader.leaves(), filterWeight)
+            : parallelSearch(reader.leaves(), filterWeight, executor);
+
     // Merge sort the results
     TopDocs topK = TopDocs.merge(k, perLeafResults);
     if (topK.scoreDocs.length == 0) {
@@ -92,7 +96,54 @@ abstract class AbstractKnnVectorQuery extends Query {
     return createRewrittenQuery(reader, topK);
   }
 
+  private TopDocs[] sequentialSearch(
+      List<LeafReaderContext> leafReaderContexts, Weight filterWeight) {
+    try {
+      TopDocs[] perLeafResults = new TopDocs[leafReaderContexts.size()];
+      for (LeafReaderContext ctx : leafReaderContexts) {
+        perLeafResults[ctx.ord] = searchLeaf(ctx, filterWeight);
+      }
+      return perLeafResults;
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  private TopDocs[] parallelSearch(
+      List<LeafReaderContext> leafReaderContexts, Weight filterWeight, Executor executor) {
+    List<FutureTask<TopDocs>> tasks =
+        leafReaderContexts.stream()
+            .map(ctx -> new FutureTask<>(() -> searchLeaf(ctx, filterWeight)))
+            .collect(Collectors.toList());
+
+    SliceExecutor sliceExecutor = new SliceExecutor(executor);
+    sliceExecutor.invokeAll(tasks);
+
+    return tasks.stream()
+        .map(
+            task -> {
+              try {
+                return task.get();
+              } catch (ExecutionException e) {
+                throw new RuntimeException(e.getCause());
+              } catch (InterruptedException e) {
+                throw new ThreadInterruptedException(e);
+              }
+            })
+        .toArray(TopDocs[]::new);
+  }
+
   private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
+    TopDocs results = getLeafResults(ctx, filterWeight);
+    if (ctx.docBase > 0) {
+      for (ScoreDoc scoreDoc : results.scoreDocs) {
+        scoreDoc.doc += ctx.docBase;
+      }
+    }
+    return results;
+  }
+
+  private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight) throws IOException {
     Bits liveDocs = ctx.reader().getLiveDocs();
     int maxDoc = ctx.reader().maxDoc();
 
diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
index af91b19444b..24de1a463c7 100644
--- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
+++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
@@ -210,7 +210,10 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
       IndexSearcher searcher = newSearcher(reader);
       AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 10);
       IllegalArgumentException e =
-          expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
+          expectThrows(
+              RuntimeException.class,
+              IllegalArgumentException.class,
+              () -> searcher.search(kvq, 10));
       assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage());
     }
   }
@@ -529,6 +532,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
           assertEquals(9, results.totalHits.value);
           assertEquals(results.totalHits.value, results.scoreDocs.length);
           expectThrows(
+              RuntimeException.class,
               UnsupportedOperationException.class,
               () ->
                   searcher.search(
@@ -543,6 +547,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
           assertEquals(5, results.totalHits.value);
           assertEquals(results.totalHits.value, results.scoreDocs.length);
           expectThrows(
+              RuntimeException.class,
               UnsupportedOperationException.class,
               () ->
                   searcher.search(
@@ -570,6 +575,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
           // Test a filter that exhausts visitedLimit in upper levels, and switches to exact search
           Query filter4 = IntPoint.newRangeQuery("tag", lower, lower + 2);
           expectThrows(
+              RuntimeException.class,
               UnsupportedOperationException.class,
               () ->
                   searcher.search(
@@ -742,6 +748,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
 
         Query filter = new ThrowingBitSetQuery(new FixedBitSet(numDocs));
         expectThrows(
+            RuntimeException.class,
             UnsupportedOperationException.class,
             () ->
                 searcher.search(