You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by rm...@apache.org on 2021/10/02 19:24:53 UTC

[lucene] branch main updated: LUCENE-10142: use a better RNG for HNSW vectors

This is an automated email from the ASF dual-hosted git repository.

rmuir pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/main by this push:
     new b4fcdd9  LUCENE-10142: use a better RNG for HNSW vectors
b4fcdd9 is described below

commit b4fcdd9770ef1abcac1287c0751b56ada6dde75a
Author: Robert Muir <rm...@apache.org>
AuthorDate: Sat Oct 2 15:23:28 2021 -0400

    LUCENE-10142: use a better RNG for HNSW vectors
    
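    This code makes extensive use of random numbers but relies on the legacy
    java.util.Random, which is slow: it advances its seed with an atomic
    compare-and-set on every call, even when used from a single thread.
    Swap in the unsynchronized java.util.SplittableRandom for better
    performance.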
---
 .../lucene/codecs/lucene90/Lucene90HnswVectorsReader.java  |  4 ++--
 .../src/java/org/apache/lucene/util/hnsw/HnswGraph.java    |  4 ++--
 .../java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java |  6 +++---
 .../test/org/apache/lucene/util/hnsw/TestHnswGraph.java    | 14 +++++++++++---
 4 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
index 70e386d..9a1c16f 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
@@ -24,7 +24,7 @@ import java.nio.ByteBuffer;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;
-import java.util.Random;
+import java.util.SplittableRandom;
 import org.apache.lucene.codecs.CodecUtil;
 import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.index.CorruptIndexException;
@@ -242,7 +242,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
     OffHeapVectorValues vectorValues = getOffHeapVectorValues(fieldEntry);
 
     // use a seed that is fixed for the index so we get reproducible results for the same query
-    final Random random = new Random(checksumSeed);
+    final SplittableRandom random = new SplittableRandom(checksumSeed);
     NeighborQueue results =
         HnswGraph.search(
             target,
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
index 43ed6ec..511f889 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
@@ -22,7 +22,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
-import java.util.Random;
+import java.util.SplittableRandom;
 import org.apache.lucene.index.KnnGraphValues;
 import org.apache.lucene.index.RandomAccessVectorValues;
 import org.apache.lucene.index.VectorSimilarityFunction;
@@ -97,7 +97,7 @@ public final class HnswGraph extends KnnGraphValues {
       VectorSimilarityFunction similarityFunction,
       KnnGraphValues graphValues,
       Bits acceptOrds,
-      Random random)
+      SplittableRandom random)
       throws IOException {
     int size = graphValues.size();
 
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index d12a731..f5cfc6a 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -20,7 +20,7 @@ package org.apache.lucene.util.hnsw;
 import java.io.IOException;
 import java.util.Locale;
 import java.util.Objects;
-import java.util.Random;
+import java.util.SplittableRandom;
 import org.apache.lucene.index.RandomAccessVectorValues;
 import org.apache.lucene.index.RandomAccessVectorValuesProducer;
 import org.apache.lucene.index.VectorSimilarityFunction;
@@ -45,7 +45,7 @@ public final class HnswGraphBuilder {
 
   private final VectorSimilarityFunction similarityFunction;
   private final RandomAccessVectorValues vectorValues;
-  private final Random random;
+  private final SplittableRandom random;
   private final BoundsChecker bound;
   final HnswGraph hnsw;
 
@@ -86,7 +86,7 @@ public final class HnswGraphBuilder {
     this.beamWidth = beamWidth;
     this.hnsw = new HnswGraph(maxConn);
     bound = BoundsChecker.create(similarityFunction.reversed);
-    random = new Random(seed);
+    random = new SplittableRandom(seed);
     scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1));
   }
 
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
index 28dc517..16a3e60 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
@@ -24,6 +24,7 @@ import java.util.Arrays;
 import java.util.HashSet;
 import java.util.Random;
 import java.util.Set;
+import java.util.SplittableRandom;
 import org.apache.lucene.codecs.KnnVectorsFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90Codec;
 import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsFormat;
@@ -141,7 +142,7 @@ public class TestHnswGraph extends LuceneTestCase {
             VectorSimilarityFunction.DOT_PRODUCT,
             hnsw,
             null,
-            random());
+            new SplittableRandom(random().nextLong()));
 
     int[] nodes = nn.nodes();
     assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
@@ -182,7 +183,7 @@ public class TestHnswGraph extends LuceneTestCase {
             VectorSimilarityFunction.DOT_PRODUCT,
             hnsw,
             acceptOrds,
-            random());
+            new SplittableRandom(random().nextLong()));
     int[] nodes = nn.nodes();
     assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
     int sum = 0;
@@ -325,7 +326,14 @@ public class TestHnswGraph extends LuceneTestCase {
       float[] query = randomVector(random(), dim);
       NeighborQueue actual =
           HnswGraph.search(
-              query, topK, 100, vectors, similarityFunction, hnsw, acceptOrds, random());
+              query,
+              topK,
+              100,
+              vectors,
+              similarityFunction,
+              hnsw,
+              acceptOrds,
+              new SplittableRandom(random().nextLong()));
       NeighborQueue expected = new NeighborQueue(topK, similarityFunction.reversed);
       for (int j = 0; j < size; j++) {
         if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {