You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by be...@apache.org on 2023/02/21 13:27:02 UTC

[lucene] branch branch_9x updated: Minor vector search matching doc optimizations (#12152)

This is an automated email from the ASF dual-hosted git repository.

benwtrent pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/branch_9x by this push:
     new 8e08036e411 Minor vector search matching doc optimizations (#12152)
8e08036e411 is described below

commit 8e08036e411e13f4a4f6265566fa3f731da425e6
Author: Benjamin Trent <be...@gmail.com>
AuthorDate: Tue Feb 21 07:51:03 2023 -0500

    Minor vector search matching doc optimizations (#12152)
    
    The two minor performance improvements are to Weight#count and Weight#scorer.
    segmentStarts holds, for each leaf-segment ordinal, the monotonically increasing offset into the docs array at which that segment's scored documents begin. Consequently, if the lower and upper bounds for a segment are equal, no docs match in that segment and Weight#scorer can return null.
    
    Count is likewise computed as the difference between the upper and lower segmentStarts entries for the segment's ordinal.
---
 .../lucene/search/AbstractKnnVectorQuery.java      | 11 +++-
 .../lucene/search/TestKnnFloatVectorQuery.java     | 72 ++++++++++++++++++++++
 2 files changed, 81 insertions(+), 2 deletions(-)

diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
index 10681203dcb..eaa7cc1e833 100644
--- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
@@ -200,7 +200,7 @@ abstract class AbstractKnnVectorQuery extends Query {
     return new DocAndScoreQuery(k, docs, scores, segmentStarts, reader.getContext().id());
   }
 
-  private int[] findSegmentStarts(IndexReader reader, int[] docs) {
+  static int[] findSegmentStarts(IndexReader reader, int[] docs) {
     int[] starts = new int[reader.leaves().size() + 1];
     starts[starts.length - 1] = docs.length;
     if (starts.length == 2) {
@@ -308,8 +308,15 @@ abstract class AbstractKnnVectorQuery extends Query {
         }
 
         @Override
-        public Scorer scorer(LeafReaderContext context) {
+        public int count(LeafReaderContext context) { // Exact per-leaf hit count, computed without iterating.
+          return segmentStarts[context.ord + 1] - segmentStarts[context.ord]; // segmentStarts is cumulative by leaf ord, so this difference is the number of hits in this leaf.
+        }
 
+        @Override
+        public Scorer scorer(LeafReaderContext context) {
+          if (segmentStarts[context.ord] == segmentStarts[context.ord + 1]) {
+            return null;
+          }
           return new Scorer(this) {
             final int lower = segmentStarts[context.ord];
             final int upper = segmentStarts[context.ord + 1];
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java
index 4c9cdd47d21..ea26278c64f 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java
@@ -16,13 +16,19 @@
  */
 package org.apache.lucene.search;
 
+import static com.carrotsearch.randomizedtesting.RandomizedTest.randomFloat;
 import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
 import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.document.StringField;
 import org.apache.lucene.index.DirectoryReader;
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.IndexWriter;
@@ -30,6 +36,8 @@ import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.util.LuceneTestCase;
 import org.apache.lucene.util.TestVectorUtil;
 import org.apache.lucene.util.VectorUtil;
 
@@ -159,6 +167,70 @@ public class TestKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase {
     }
   }
 
+  public void testDocAndScoreQueryBasics() throws IOException { // Checks that DocAndScoreQuery search hits, explain(), Weight#count and Weight#scorer agree per leaf.
+    try (Directory directory = newDirectory()) {
+      final DirectoryReader reader;
+      try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) {
+        for (int i = 0; i < 50; i++) {
+          Document doc = new Document();
+          doc.add(new StringField("field", "value" + i, Field.Store.NO));
+          iw.addDocument(doc);
+          if (i % 10 == 0) { // Flush periodically so the index is likely to span several segments (RandomIndexWriter may still merge).
+            iw.flush();
+          }
+        }
+        reader = iw.getReader();
+      }
+      try (reader) {
+        IndexSearcher searcher = LuceneTestCase.newSearcher(reader);
+        List<ScoreDoc> scoreDocsList = new ArrayList<>();
+        for (int doc = 0; doc < 30; doc += 1 + random().nextInt(5)) { // Strictly increasing doc ids in [0, 30) with random gaps, so some leaves may get no hits.
+          scoreDocsList.add(new ScoreDoc(doc, randomFloat()));
+        }
+        ScoreDoc[] scoreDocs = scoreDocsList.toArray(new ScoreDoc[0]);
+        int[] docs = new int[scoreDocs.length];
+        float[] scores = new float[scoreDocs.length];
+        float maxScore = Float.MIN_VALUE; // NOTE(review): Float.MIN_VALUE is the smallest positive float, not the most negative; maxScore is also never read below — consider removing.
+        for (int i = 0; i < scoreDocs.length; i++) {
+          docs[i] = scoreDocs[i].doc;
+          scores[i] = scoreDocs[i].score;
+          maxScore = Math.max(maxScore, scores[i]);
+        }
+        int[] segments = AbstractKnnVectorQuery.findSegmentStarts(reader, docs); // Made package-private in this commit so the test can build the query directly.
+
+        AbstractKnnVectorQuery.DocAndScoreQuery query =
+            new AbstractKnnVectorQuery.DocAndScoreQuery(
+                scoreDocs.length, docs, scores, segments, reader.getContext().id());
+        final Weight w = query.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f); // Held so scorer() and count() can be compared leaf by leaf below.
+        TopDocs topDocs = searcher.search(query, 100);
+        assertEquals(scoreDocs.length, topDocs.totalHits.value);
+        assertEquals(TotalHits.Relation.EQUAL_TO, topDocs.totalHits.relation);
+        Arrays.sort(topDocs.scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc)); // Hits come back in score order; sort by doc id to line up with the ascending scoreDocs.
+        assertEquals(scoreDocs.length, topDocs.scoreDocs.length);
+        for (int i = 0; i < scoreDocs.length; i++) {
+          assertEquals(scoreDocs[i].doc, topDocs.scoreDocs[i].doc);
+          assertEquals(scoreDocs[i].score, topDocs.scoreDocs[i].score, 0.0001f);
+          assertTrue(searcher.explain(query, scoreDocs[i].doc).isMatch());
+        }
+
+        for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
+          final Scorer scorer = w.scorer(leafReaderContext);
+          final int count = w.count(leafReaderContext);
+          if (scorer == null) { // A null scorer (no hits in this leaf) must agree with a zero count.
+            assertEquals(0, count);
+          } else {
+            assertTrue(count > 0);
+            int iteratorCount = 0;
+            while (scorer.iterator().nextDoc() != NO_MORE_DOCS) { // NOTE(review): assumes iterator() returns the same instance on every call — confirm.
+              iteratorCount++;
+            }
+            assertEquals(iteratorCount, count);
+          }
+        }
+      }
+    }
+  }
+
   private static class ThrowingKnnVectorQuery extends KnnFloatVectorQuery {
 
     public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter) {