You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by rm...@apache.org on 2021/10/02 19:24:53 UTC
[lucene] branch main updated: LUCENE-10142: use a better RNG for HNSW vectors
This is an automated email from the ASF dual-hosted git repository.
rmuir pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/lucene.git
The following commit(s) were added to refs/heads/main by this push:
new b4fcdd9 LUCENE-10142: use a better RNG for HNSW vectors
b4fcdd9 is described below
commit b4fcdd9770ef1abcac1287c0751b56ada6dde75a
Author: Robert Muir <rm...@apache.org>
AuthorDate: Sat Oct 2 15:23:28 2021 -0400
LUCENE-10142: use a better RNG for HNSW vectors
This code makes extensive use of Random, but relies on the legacy
java.util.Random, which is slow. Swap in SplittableRandom for better
performance.
---
.../lucene/codecs/lucene90/Lucene90HnswVectorsReader.java | 4 ++--
.../src/java/org/apache/lucene/util/hnsw/HnswGraph.java | 4 ++--
.../java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java | 6 +++---
.../test/org/apache/lucene/util/hnsw/TestHnswGraph.java | 14 +++++++++++---
4 files changed, 18 insertions(+), 10 deletions(-)
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
index 70e386d..9a1c16f 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
@@ -24,7 +24,7 @@ import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
-import java.util.Random;
+import java.util.SplittableRandom;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.CorruptIndexException;
@@ -242,7 +242,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
OffHeapVectorValues vectorValues = getOffHeapVectorValues(fieldEntry);
// use a seed that is fixed for the index so we get reproducible results for the same query
- final Random random = new Random(checksumSeed);
+ final SplittableRandom random = new SplittableRandom(checksumSeed);
NeighborQueue results =
HnswGraph.search(
target,
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
index 43ed6ec..511f889 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
@@ -22,7 +22,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
-import java.util.Random;
+import java.util.SplittableRandom;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
@@ -97,7 +97,7 @@ public final class HnswGraph extends KnnGraphValues {
VectorSimilarityFunction similarityFunction,
KnnGraphValues graphValues,
Bits acceptOrds,
- Random random)
+ SplittableRandom random)
throws IOException {
int size = graphValues.size();
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index d12a731..f5cfc6a 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -20,7 +20,7 @@ package org.apache.lucene.util.hnsw;
import java.io.IOException;
import java.util.Locale;
import java.util.Objects;
-import java.util.Random;
+import java.util.SplittableRandom;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorSimilarityFunction;
@@ -45,7 +45,7 @@ public final class HnswGraphBuilder {
private final VectorSimilarityFunction similarityFunction;
private final RandomAccessVectorValues vectorValues;
- private final Random random;
+ private final SplittableRandom random;
private final BoundsChecker bound;
final HnswGraph hnsw;
@@ -86,7 +86,7 @@ public final class HnswGraphBuilder {
this.beamWidth = beamWidth;
this.hnsw = new HnswGraph(maxConn);
bound = BoundsChecker.create(similarityFunction.reversed);
- random = new Random(seed);
+ random = new SplittableRandom(seed);
scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1));
}
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
index 28dc517..16a3e60 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
@@ -24,6 +24,7 @@ import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;
+import java.util.SplittableRandom;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90Codec;
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsFormat;
@@ -141,7 +142,7 @@ public class TestHnswGraph extends LuceneTestCase {
VectorSimilarityFunction.DOT_PRODUCT,
hnsw,
null,
- random());
+ new SplittableRandom(random().nextLong()));
int[] nodes = nn.nodes();
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
@@ -182,7 +183,7 @@ public class TestHnswGraph extends LuceneTestCase {
VectorSimilarityFunction.DOT_PRODUCT,
hnsw,
acceptOrds,
- random());
+ new SplittableRandom(random().nextLong()));
int[] nodes = nn.nodes();
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
int sum = 0;
@@ -325,7 +326,14 @@ public class TestHnswGraph extends LuceneTestCase {
float[] query = randomVector(random(), dim);
NeighborQueue actual =
HnswGraph.search(
- query, topK, 100, vectors, similarityFunction, hnsw, acceptOrds, random());
+ query,
+ topK,
+ 100,
+ vectors,
+ similarityFunction,
+ hnsw,
+ acceptOrds,
+ new SplittableRandom(random().nextLong()));
NeighborQueue expected = new NeighborQueue(topK, similarityFunction.reversed);
for (int j = 0; j < size; j++) {
if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {