You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink itself was not preserved in this plain-text copy).
Posted to commits@lucene.apache.org by be...@apache.org on 2023/05/12 14:59:40 UTC
[lucene] branch branch_9x updated: Backport: Concurrent rewrite for KnnVectorQuery (#12160) (#12288)
This is an automated email from the ASF dual-hosted git repository.
benwtrent pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git
The following commit(s) were added to refs/heads/branch_9x by this push:
new 972de546e9b Backport: Concurrent rewrite for KnnVectorQuery (#12160) (#12288)
972de546e9b is described below
commit 972de546e9be2300720deaa456c8ac38e7e9f423
Author: Benjamin Trent <be...@gmail.com>
AuthorDate: Fri May 12 10:59:34 2023 -0400
Backport: Concurrent rewrite for KnnVectorQuery (#12160) (#12288)
* Concurrent rewrite for KnnVectorQuery (#12160)
- Reduce overhead of non-concurrent search by preserving original execution
- Improve readability by factoring into separate functions
---------
Co-authored-by: Kaival Parikh <ka...@gmail.com>
* adjusting for backport
---------
Co-authored-by: Kaival Parikh <46...@users.noreply.github.com>
Co-authored-by: Kaival Parikh <ka...@gmail.com>
---
lucene/CHANGES.txt | 2 +
.../lucene/search/AbstractKnnVectorQuery.java | 73 ++++++++++++++++++----
.../lucene/search/BaseKnnVectorQueryTestCase.java | 9 ++-
3 files changed, 72 insertions(+), 12 deletions(-)
diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 2f3a572a1bf..8875a3f0953 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -25,6 +25,8 @@ Optimizations
* GITHUB#12270 Don't generate stacktrace in CollectionTerminatedException. (Armin Braun)
+* GITHUB#12160: Concurrent rewrite for AbstractKnnVectorQuery. (Kaival Parikh)
+
Bug Fixes
---------------------
(No changes)
diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
index 9403a6413e2..d6b9e04b542 100644
--- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
@@ -21,7 +21,12 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
+import java.util.List;
import java.util.Objects;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executor;
+import java.util.concurrent.FutureTask;
+import java.util.stream.Collectors;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
@@ -29,6 +34,7 @@ import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.ThreadInterruptedException;
/**
* Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
@@ -62,9 +68,8 @@ abstract class AbstractKnnVectorQuery extends Query {
@Override
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
IndexReader reader = indexSearcher.getIndexReader();
- TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
- Weight filterWeight = null;
+ final Weight filterWeight;
if (filter != null) {
BooleanQuery booleanQuery =
new BooleanQuery.Builder()
@@ -73,17 +78,16 @@ abstract class AbstractKnnVectorQuery extends Query {
.build();
Query rewritten = indexSearcher.rewrite(booleanQuery);
filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
+ } else {
+ filterWeight = null;
}
- for (LeafReaderContext ctx : reader.leaves()) {
- TopDocs results = searchLeaf(ctx, filterWeight);
- if (ctx.docBase > 0) {
- for (ScoreDoc scoreDoc : results.scoreDocs) {
- scoreDoc.doc += ctx.docBase;
- }
- }
- perLeafResults[ctx.ord] = results;
- }
+ Executor executor = indexSearcher.getExecutor();
+ TopDocs[] perLeafResults =
+ (executor == null)
+ ? sequentialSearch(reader.leaves(), filterWeight)
+ : parallelSearch(reader.leaves(), filterWeight, executor);
+
// Merge sort the results
TopDocs topK = TopDocs.merge(k, perLeafResults);
if (topK.scoreDocs.length == 0) {
@@ -92,7 +96,54 @@ abstract class AbstractKnnVectorQuery extends Query {
return createRewrittenQuery(reader, topK);
}
+ /**
+ * Searches every leaf on the calling thread, one after another, and stores each leaf's hits at
+ * index {@code ctx.ord} of the returned array.
+ *
+ * <p>Any exception thrown while searching a leaf, checked or unchecked, is rethrown wrapped in a
+ * {@link RuntimeException}.
+ *
+ * @param leafReaderContexts the leaves to search
+ * @param filterWeight optional filter weight passed through to each leaf search; may be null
+ * @return per-leaf results, indexed by leaf ord
+ */
+ private TopDocs[] sequentialSearch(
+ List<LeafReaderContext> leafReaderContexts, Weight filterWeight) {
+ try {
+ TopDocs[] perLeafResults = new TopDocs[leafReaderContexts.size()];
+ for (LeafReaderContext ctx : leafReaderContexts) {
+ // Indexing by ord assumes the list is the reader's full leaves() list, so every ord is a
+ // valid index (presumably true for the reader.leaves() call in rewrite; verify for others).
+ perLeafResults[ctx.ord] = searchLeaf(ctx, filterWeight);
+ }
+ return perLeafResults;
+ } catch (Exception e) {
+ // NOTE(review): this wrapping changes the exception types rewrite() callers observe. An
+ // IOException (already declared by rewrite) or an IllegalArgumentException from a leaf now
+ // surfaces as a plain RuntimeException, which is why this patch's tests unwrap with
+ // expectThrows(RuntimeException.class, ...). Declaring `throws IOException` and dropping the
+ // try/catch would restore the pre-patch behaviour (and the original test expectations).
+ throw new RuntimeException(e);
+ }
+ }
+
+ /**
+ * Searches every leaf as its own {@link FutureTask}, dispatched through a {@link SliceExecutor}
+ * built around {@code executor}, then collects the task results in list order.
+ *
+ * <p>An exception thrown inside a task is rethrown as a {@link RuntimeException} whose cause is
+ * that exception. An interrupt while waiting for a task is rethrown as a {@link
+ * ThreadInterruptedException}.
+ *
+ * @param leafReaderContexts the leaves to search, one task per leaf
+ * @param filterWeight optional filter weight passed through to each leaf search; may be null
+ * @param executor the searcher's executor
+ * @return per-leaf results, in the same order as {@code leafReaderContexts}
+ */
+ private TopDocs[] parallelSearch(
+ List<LeafReaderContext> leafReaderContexts, Weight filterWeight, Executor executor) {
+ List<FutureTask<TopDocs>> tasks =
+ leafReaderContexts.stream()
+ .map(ctx -> new FutureTask<>(() -> searchLeaf(ctx, filterWeight)))
+ .collect(Collectors.toList());
+
+ // SliceExecutor's scheduling policy is not visible here. task.get() below blocks on each task
+ // either way, so correctness does not depend on whether invokeAll waits for completion.
+ SliceExecutor sliceExecutor = new SliceExecutor(executor);
+ sliceExecutor.invokeAll(tasks);
+
+ // Results are gathered in list order, not placed by ctx.ord as sequentialSearch does. The two
+ // agree when the list is in ord order (presumably true for reader.leaves(); verify).
+ return tasks.stream()
+ .map(
+ task -> {
+ try {
+ return task.get();
+ } catch (ExecutionException e) {
+ // NOTE(review): every cause is wrapped in a plain RuntimeException, including
+ // unchecked ones (e.g. the IllegalArgumentException the tests expect), Errors and
+ // IOException. Rethrowing RuntimeException/Error causes as-is, and IOException as
+ // IOException (which needs a plain loop instead of this stream), would keep the
+ // exception types rewrite() threw before this change.
+ throw new RuntimeException(e.getCause());
+ } catch (InterruptedException e) {
+ // NOTE(review): catching InterruptedException clears the thread's interrupt
+ // status, and it is not re-set here before rethrowing; confirm this is intended.
+ throw new ThreadInterruptedException(e);
+ }
+ })
+ .toArray(TopDocs[]::new);
+ }
+
+ /**
+ * Searches a single leaf via getLeafResults and adds {@code ctx.docBase} to every hit's doc id,
+ * presumably converting leaf-local ids into ids of the top-level reader.
+ *
+ * <p>The returned {@link ScoreDoc}s are adjusted in place. This method is called both from
+ * sequentialSearch and from the per-leaf tasks that parallelSearch builds.
+ *
+ * @param ctx the leaf to search
+ * @param filterWeight optional filter weight; may be null
+ * @return the leaf's hits with doc ids offset by {@code ctx.docBase}
+ * @throws IOException if the leaf search fails
+ */
private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
+ TopDocs results = getLeafResults(ctx, filterWeight);
+ // A zero docBase needs no adjustment, so skip the loop entirely.
+ if (ctx.docBase > 0) {
+ for (ScoreDoc scoreDoc : results.scoreDocs) {
+ scoreDoc.doc += ctx.docBase;
+ }
+ }
+ return results;
+ }
+
+ private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight) throws IOException {
Bits liveDocs = ctx.reader().getLiveDocs();
int maxDoc = ctx.reader().maxDoc();
diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
index af91b19444b..24de1a463c7 100644
--- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
+++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
@@ -210,7 +210,10 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
IndexSearcher searcher = newSearcher(reader);
AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 10);
IllegalArgumentException e =
- expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
+ expectThrows(
+ RuntimeException.class,
+ IllegalArgumentException.class,
+ () -> searcher.search(kvq, 10));
assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage());
}
}
@@ -529,6 +532,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
assertEquals(9, results.totalHits.value);
assertEquals(results.totalHits.value, results.scoreDocs.length);
expectThrows(
+ RuntimeException.class,
UnsupportedOperationException.class,
() ->
searcher.search(
@@ -543,6 +547,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
assertEquals(5, results.totalHits.value);
assertEquals(results.totalHits.value, results.scoreDocs.length);
expectThrows(
+ RuntimeException.class,
UnsupportedOperationException.class,
() ->
searcher.search(
@@ -570,6 +575,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
// Test a filter that exhausts visitedLimit in upper levels, and switches to exact search
Query filter4 = IntPoint.newRangeQuery("tag", lower, lower + 2);
expectThrows(
+ RuntimeException.class,
UnsupportedOperationException.class,
() ->
searcher.search(
@@ -742,6 +748,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
Query filter = new ThrowingBitSetQuery(new FixedBitSet(numDocs));
expectThrows(
+ RuntimeException.class,
UnsupportedOperationException.class,
() ->
searcher.search(