You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by so...@apache.org on 2022/05/04 22:22:55 UTC
[lucene] branch main updated: LUCENE-10504: KnnGraphTester to use KnnVectorQuery (#796)
This is an automated email from the ASF dual-hosted git repository.
sokolov pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/lucene.git
The following commit(s) were added to refs/heads/main by this push:
new 7fbaa63dd1f LUCENE-10504: KnnGraphTester to use KnnVectorQuery (#796)
7fbaa63dd1f is described below
commit 7fbaa63dd1f3eb441ce966e0876abace4310e8fe
Author: Michael Sokolov <so...@falutin.net>
AuthorDate: Wed May 4 18:22:48 2022 -0400
LUCENE-10504: KnnGraphTester to use KnnVectorQuery (#796)
* LUCENE-10504: KnnGraphTester to use KnnVectorQuery
---
.../apache/lucene/util/hnsw/KnnGraphTester.java | 52 +++++++++++-----------
1 file changed, 26 insertions(+), 26 deletions(-)
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
index 57389d098c9..a14bcd73be1 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
@@ -32,8 +32,10 @@ import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
+import java.nio.file.attribute.FileTime;
import java.util.HashSet;
import java.util.Locale;
+import java.util.Objects;
import java.util.Set;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
@@ -47,7 +49,6 @@ import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader;
-import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
@@ -55,11 +56,12 @@ import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
-import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IntroSorter;
import org.apache.lucene.util.PrintStreamInfoStream;
@@ -79,7 +81,6 @@ public class KnnGraphTester {
private int numDocs;
private int dim;
private int topK;
- private int warmCount;
private int numIters;
private int fanout;
private Path indexPath;
@@ -98,7 +99,6 @@ public class KnnGraphTester {
numIters = 1000;
dim = 256;
topK = 100;
- warmCount = 1000;
fanout = topK;
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
}
@@ -178,9 +178,6 @@ public class KnnGraphTester {
case "-out":
outputPath = Paths.get(args[++iarg]);
break;
- case "-warm":
- warmCount = Integer.parseInt(args[++iarg]);
- break;
case "-docs":
docVectorsPath = Paths.get(args[++iarg]);
break;
@@ -350,8 +347,9 @@ public class KnnGraphTester {
TopDocs[] results = new TopDocs[numIters];
long elapsed, totalCpuTime, totalVisited = 0;
try (FileChannel q = FileChannel.open(queryPath)) {
+ int bufferSize = numIters * dim * Float.BYTES;
FloatBuffer targets =
- q.map(FileChannel.MapMode.READ_ONLY, 0, numIters * dim * Float.BYTES)
+ q.map(FileChannel.MapMode.READ_ONLY, 0, bufferSize)
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer();
float[] target = new float[dim];
@@ -363,18 +361,19 @@ public class KnnGraphTester {
long cpuTimeStartNs;
try (Directory dir = FSDirectory.open(indexPath);
DirectoryReader reader = DirectoryReader.open(dir)) {
+ IndexSearcher searcher = new IndexSearcher(reader);
numDocs = reader.maxDoc();
- for (int i = 0; i < warmCount; i++) {
+ for (int i = 0; i < numIters; i++) {
// warm up
targets.get(target);
- results[i] = doKnnSearch(reader, KNN_FIELD, target, topK, fanout);
+ doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout);
}
targets.position(0);
start = System.nanoTime();
cpuTimeStartNs = bean.getCurrentThreadCpuTime();
for (int i = 0; i < numIters; i++) {
targets.get(target);
- results[i] = doKnnSearch(reader, KNN_FIELD, target, topK, fanout);
+ results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout);
}
totalCpuTime = (bean.getCurrentThreadCpuTime() - cpuTimeStartNs) / 1_000_000;
elapsed = (System.nanoTime() - start) / 1_000_000; // ns -> ms
@@ -431,19 +430,9 @@ public class KnnGraphTester {
}
}
- private static TopDocs doKnnSearch(
- IndexReader reader, String field, float[] vector, int k, int fanout) throws IOException {
- TopDocs[] results = new TopDocs[reader.leaves().size()];
- for (LeafReaderContext ctx : reader.leaves()) {
- Bits liveDocs = ctx.reader().getLiveDocs();
- results[ctx.ord] =
- ctx.reader().searchNearestVectors(field, vector, k + fanout, liveDocs, Integer.MAX_VALUE);
- int docBase = ctx.docBase;
- for (ScoreDoc scoreDoc : results[ctx.ord].scoreDocs) {
- scoreDoc.doc += docBase;
- }
- }
- return TopDocs.merge(k, results);
+ private static TopDocs doKnnVectorQuery(
+ IndexSearcher searcher, String field, float[] vector, int k, int fanout) throws IOException {
+ return searcher.search(new KnnVectorQuery(field, vector, k + fanout), k);
}
private float checkResults(TopDocs[] results, int[][] nn) {
@@ -488,9 +477,10 @@ public class KnnGraphTester {
private int[][] getNN(Path docPath, Path queryPath) throws IOException {
// look in working directory for cached nn file
- String nnFileName = "nn-" + numDocs + "-" + numIters + "-" + topK + "-" + dim + ".bin";
+ String hash = Integer.toString(Objects.hash(docPath, queryPath, numDocs, numIters, topK), 36);
+ String nnFileName = "nn-" + hash + ".bin";
Path nnPath = Paths.get(nnFileName);
- if (Files.exists(nnPath)) {
+ if (Files.exists(nnPath) && isNewer(nnPath, docPath, queryPath)) {
return readNN(nnPath);
} else {
int[][] nn = computeNN(docPath, queryPath);
@@ -499,6 +489,16 @@ public class KnnGraphTester {
}
}
+ private boolean isNewer(Path path, Path... others) throws IOException {
+ FileTime modified = Files.getLastModifiedTime(path);
+ for (Path other : others) {
+ if (Files.getLastModifiedTime(other).compareTo(modified) >= 0) {
+ return false;
+ }
+ }
+ return true;
+ }
+
private int[][] readNN(Path nnPath) throws IOException {
int[][] result = new int[numIters][];
try (FileChannel in = FileChannel.open(nnPath)) {