You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text rendering).
Posted to commits@lucene.apache.org by so...@apache.org on 2022/05/04 22:22:55 UTC

[lucene] branch main updated: LUCENE-10504: KnnGraphTester to use KnnVectorQuery (#796)

This is an automated email from the ASF dual-hosted git repository.

sokolov pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/main by this push:
     new 7fbaa63dd1f LUCENE-10504: KnnGraphTester to use KnnVectorQuery (#796)
7fbaa63dd1f is described below

commit 7fbaa63dd1f3eb441ce966e0876abace4310e8fe
Author: Michael Sokolov <so...@falutin.net>
AuthorDate: Wed May 4 18:22:48 2022 -0400

    LUCENE-10504: KnnGraphTester to use KnnVectorQuery (#796)
    
    * LUCENE-10504: KnnGraphTester to use KnnVectorQuery
---
 .../apache/lucene/util/hnsw/KnnGraphTester.java    | 52 +++++++++++-----------
 1 file changed, 26 insertions(+), 26 deletions(-)

diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
index 57389d098c9..a14bcd73be1 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
@@ -32,8 +32,10 @@ import java.nio.channels.FileChannel;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.Paths;
+import java.nio.file.attribute.FileTime;
 import java.util.HashSet;
 import java.util.Locale;
+import java.util.Objects;
 import java.util.Set;
 import org.apache.lucene.codecs.KnnVectorsFormat;
 import org.apache.lucene.codecs.KnnVectorsReader;
@@ -47,7 +49,6 @@ import org.apache.lucene.document.KnnVectorField;
 import org.apache.lucene.document.StoredField;
 import org.apache.lucene.index.CodecReader;
 import org.apache.lucene.index.DirectoryReader;
-import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.IndexWriter;
 import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.index.LeafReader;
@@ -55,11 +56,12 @@ import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.RandomAccessVectorValues;
 import org.apache.lucene.index.RandomAccessVectorValuesProducer;
 import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.KnnVectorQuery;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.store.FSDirectory;
-import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.IntroSorter;
 import org.apache.lucene.util.PrintStreamInfoStream;
@@ -79,7 +81,6 @@ public class KnnGraphTester {
   private int numDocs;
   private int dim;
   private int topK;
-  private int warmCount;
   private int numIters;
   private int fanout;
   private Path indexPath;
@@ -98,7 +99,6 @@ public class KnnGraphTester {
     numIters = 1000;
     dim = 256;
     topK = 100;
-    warmCount = 1000;
     fanout = topK;
     similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
   }
@@ -178,9 +178,6 @@ public class KnnGraphTester {
         case "-out":
           outputPath = Paths.get(args[++iarg]);
           break;
-        case "-warm":
-          warmCount = Integer.parseInt(args[++iarg]);
-          break;
         case "-docs":
           docVectorsPath = Paths.get(args[++iarg]);
           break;
@@ -350,8 +347,9 @@ public class KnnGraphTester {
     TopDocs[] results = new TopDocs[numIters];
     long elapsed, totalCpuTime, totalVisited = 0;
     try (FileChannel q = FileChannel.open(queryPath)) {
+      int bufferSize = numIters * dim * Float.BYTES;
       FloatBuffer targets =
-          q.map(FileChannel.MapMode.READ_ONLY, 0, numIters * dim * Float.BYTES)
+          q.map(FileChannel.MapMode.READ_ONLY, 0, bufferSize)
               .order(ByteOrder.LITTLE_ENDIAN)
               .asFloatBuffer();
       float[] target = new float[dim];
@@ -363,18 +361,19 @@ public class KnnGraphTester {
       long cpuTimeStartNs;
       try (Directory dir = FSDirectory.open(indexPath);
           DirectoryReader reader = DirectoryReader.open(dir)) {
+        IndexSearcher searcher = new IndexSearcher(reader);
         numDocs = reader.maxDoc();
-        for (int i = 0; i < warmCount; i++) {
+        for (int i = 0; i < numIters; i++) {
           // warm up
           targets.get(target);
-          results[i] = doKnnSearch(reader, KNN_FIELD, target, topK, fanout);
+          doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout);
         }
         targets.position(0);
         start = System.nanoTime();
         cpuTimeStartNs = bean.getCurrentThreadCpuTime();
         for (int i = 0; i < numIters; i++) {
           targets.get(target);
-          results[i] = doKnnSearch(reader, KNN_FIELD, target, topK, fanout);
+          results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout);
         }
         totalCpuTime = (bean.getCurrentThreadCpuTime() - cpuTimeStartNs) / 1_000_000;
         elapsed = (System.nanoTime() - start) / 1_000_000; // ns -> ms
@@ -431,19 +430,9 @@ public class KnnGraphTester {
     }
   }
 
-  private static TopDocs doKnnSearch(
-      IndexReader reader, String field, float[] vector, int k, int fanout) throws IOException {
-    TopDocs[] results = new TopDocs[reader.leaves().size()];
-    for (LeafReaderContext ctx : reader.leaves()) {
-      Bits liveDocs = ctx.reader().getLiveDocs();
-      results[ctx.ord] =
-          ctx.reader().searchNearestVectors(field, vector, k + fanout, liveDocs, Integer.MAX_VALUE);
-      int docBase = ctx.docBase;
-      for (ScoreDoc scoreDoc : results[ctx.ord].scoreDocs) {
-        scoreDoc.doc += docBase;
-      }
-    }
-    return TopDocs.merge(k, results);
+  private static TopDocs doKnnVectorQuery(
+      IndexSearcher searcher, String field, float[] vector, int k, int fanout) throws IOException {
+    return searcher.search(new KnnVectorQuery(field, vector, k + fanout), k);
   }
 
   private float checkResults(TopDocs[] results, int[][] nn) {
@@ -488,9 +477,10 @@ public class KnnGraphTester {
 
   private int[][] getNN(Path docPath, Path queryPath) throws IOException {
     // look in working directory for cached nn file
-    String nnFileName = "nn-" + numDocs + "-" + numIters + "-" + topK + "-" + dim + ".bin";
+    String hash = Integer.toString(Objects.hash(docPath, queryPath, numDocs, numIters, topK), 36);
+    String nnFileName = "nn-" + hash + ".bin";
     Path nnPath = Paths.get(nnFileName);
-    if (Files.exists(nnPath)) {
+    if (Files.exists(nnPath) && isNewer(nnPath, docPath, queryPath)) {
       return readNN(nnPath);
     } else {
       int[][] nn = computeNN(docPath, queryPath);
@@ -499,6 +489,16 @@ public class KnnGraphTester {
     }
   }
 
+  private boolean isNewer(Path path, Path... others) throws IOException {
+    FileTime modified = Files.getLastModifiedTime(path);
+    for (Path other : others) {
+      if (Files.getLastModifiedTime(other).compareTo(modified) >= 0) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   private int[][] readNN(Path nnPath) throws IOException {
     int[][] result = new int[numIters][];
     try (FileChannel in = FileChannel.open(nnPath)) {