You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by pa...@apache.org on 2023/05/22 04:49:34 UTC

[lucene] branch branch_9x updated: Add multi-thread searchability to OnHeapHnswGraph (#12257)

This is an automated email from the ASF dual-hosted git repository.

patrickz pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/branch_9x by this push:
     new 25a908d015e Add multi-thread searchability to OnHeapHnswGraph (#12257)
25a908d015e is described below

commit 25a908d015eecdbf1441e8466c208b0e58a1b0f6
Author: Patrick Zhai <zh...@users.noreply.github.com>
AuthorDate: Sun May 21 21:48:46 2023 -0700

    Add multi-thread searchability to OnHeapHnswGraph (#12257)
---
 lucene/CHANGES.txt                                 |   3 +-
 .../apache/lucene/util/hnsw/HnswGraphSearcher.java | 145 +++++++++++++++++----
 .../apache/lucene/util/hnsw/HnswGraphTestCase.java | 106 +++++++++++++++
 3 files changed, 230 insertions(+), 24 deletions(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 4bd42ace6d0..ebd8486df15 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -17,7 +17,8 @@ API Changes
 
 New Features
 ---------------------
-(No changes)
+
+* GITHUB#12257: Create OnHeapHnswGraphSearcher to let OnHeapHnswGraph be searched in a thread-safe manner. (Patrick Zhai)
 
 Improvements
 ---------------------
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
index 4857d5b9d57..d6e63f483b2 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
@@ -100,28 +100,31 @@ public class HnswGraphSearcher<T> {
             similarityFunction,
             new NeighborQueue(topK, true),
             new SparseFixedBitSet(vectors.size()));
-    NeighborQueue results;
+    return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit);
+  }
 
-    int initialEp = graph.entryNode();
-    if (initialEp == -1) {
-      return new NeighborQueue(1, true);
-    }
-    int[] eps = new int[] {initialEp};
-    int numVisited = 0;
-    for (int level = graph.numLevels() - 1; level >= 1; level--) {
-      results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit);
-      numVisited += results.visitedCount();
-      visitedLimit -= results.visitedCount();
-      if (results.incomplete()) {
-        results.setVisitedCount(numVisited);
-        return results;
-      }
-      eps[0] = results.pop();
-    }
-    results =
-        graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit);
-    results.setVisitedCount(results.visitedCount() + numVisited);
-    return results;
+  /**
+   * Searches an {@link OnHeapHnswGraph}; safe to call concurrently if each thread has its own
+   * {@code vectors} instance. See {@link #search(float[], int, RandomAccessVectorValues,
+   * VectorEncoding, VectorSimilarityFunction, HnswGraph, Bits, int)} for the parameters.
+   */
+  public static NeighborQueue search(
+      float[] query,
+      int topK,
+      RandomAccessVectorValues<float[]> vectors,
+      VectorEncoding vectorEncoding,
+      VectorSimilarityFunction similarityFunction,
+      OnHeapHnswGraph graph,
+      Bits acceptOrds,
+      int visitedLimit)
+      throws IOException {
+    OnHeapHnswGraphSearcher<float[]> graphSearcher =
+        new OnHeapHnswGraphSearcher<>(
+            vectorEncoding,
+            similarityFunction,
+            new NeighborQueue(topK, true),
+            new SparseFixedBitSet(vectors.size()));
+    return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit);
   }
 
   /**
@@ -161,6 +164,46 @@ public class HnswGraphSearcher<T> {
             similarityFunction,
             new NeighborQueue(topK, true),
             new SparseFixedBitSet(vectors.size()));
+    return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit);
+  }
+
+  /**
+   * Searches an {@link OnHeapHnswGraph}; safe to call concurrently if each thread has its own
+   * {@code vectors} instance. See {@link #search(byte[], int, RandomAccessVectorValues,
+   * VectorEncoding, VectorSimilarityFunction, HnswGraph, Bits, int)} for the parameters.
+   */
+  public static NeighborQueue search(
+      byte[] query,
+      int topK,
+      RandomAccessVectorValues<byte[]> vectors,
+      VectorEncoding vectorEncoding,
+      VectorSimilarityFunction similarityFunction,
+      OnHeapHnswGraph graph,
+      Bits acceptOrds,
+      int visitedLimit)
+      throws IOException {
+    OnHeapHnswGraphSearcher<byte[]> graphSearcher =
+        new OnHeapHnswGraphSearcher<>(
+            vectorEncoding,
+            similarityFunction,
+            new NeighborQueue(topK, true),
+            new SparseFixedBitSet(vectors.size()));
+    return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit);
+  }
+
+  private static <T> NeighborQueue search(
+      T query,
+      int topK,
+      RandomAccessVectorValues<T> vectors,
+      HnswGraph graph,
+      HnswGraphSearcher<T> graphSearcher,
+      Bits acceptOrds,
+      int visitedLimit)
+      throws IOException {
+    int initialEp = graph.entryNode();
+    if (initialEp == -1) {
+      return new NeighborQueue(1, true);
+    }
     NeighborQueue results;
     int[] eps = new int[] {graph.entryNode()};
     int numVisited = 0;
@@ -252,9 +295,9 @@ public class HnswGraphSearcher<T> {
       }
 
       int topCandidateNode = candidates.pop();
-      graph.seek(level, topCandidateNode);
+      graphSeek(graph, level, topCandidateNode);
       int friendOrd;
-      while ((friendOrd = graph.nextNeighbor()) != NO_MORE_DOCS) {
+      while ((friendOrd = graphNextNeighbor(graph)) != NO_MORE_DOCS) {
         assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
         if (visited.getAndSet(friendOrd)) {
           continue;
@@ -298,4 +341,60 @@ public class HnswGraphSearcher<T> {
     }
     visited.clear(0, visited.length());
   }
+
+  /**
+   * Seeks {@code graph} to {@code targetNode} at {@code level}. The default implementation just
+   * calls {@link HnswGraph#seek(int, int)}; subclasses may instead read neighbors directly.
+   *
+   * @throws IOException if seeking the graph fails
+   */
+  void graphSeek(HnswGraph graph, int level, int targetNode) throws IOException {
+    graph.seek(level, targetNode);
+  }
+
+  /**
+   * Returns the next neighbor of the node selected by the last call to {@link
+   * #graphSeek(HnswGraph, int, int)}, which must precede this call. The default implementation
+   * just calls {@link HnswGraph#nextNeighbor()}.
+   *
+   * @return the next neighbor, or {@code NO_MORE_DOCS} when there are no more neighbors
+   * @throws IOException if advancing to the next neighbor fails
+   */
+  int graphNextNeighbor(HnswGraph graph) throws IOException {
+    return graph.nextNeighbor();
+  }
+
+  /**
+   * Lets {@link OnHeapHnswGraph} be searched concurrently by reading neighbor arrays directly.
+   *
+   * <p>Instances are NOT thread safe (they hold per-search cursor state), but every search call
+   * creates its own instance; NOTE(review): this assumes the graph is not modified concurrently.
+   */
+  private static class OnHeapHnswGraphSearcher<C> extends HnswGraphSearcher<C> {
+
+    private NeighborArray cur; // neighbors of the node passed to the most recent graphSeek
+    private int upto; // index in cur of the last neighbor returned; -1 right after a seek
+
+    private OnHeapHnswGraphSearcher(
+        VectorEncoding vectorEncoding,
+        VectorSimilarityFunction similarityFunction,
+        NeighborQueue candidates,
+        BitSet visited) {
+      super(vectorEncoding, similarityFunction, candidates, visited);
+    }
+
+    @Override
+    void graphSeek(HnswGraph graph, int level, int targetNode) {
+      cur = ((OnHeapHnswGraph) graph).getNeighbors(level, targetNode);
+      upto = -1;
+    }
+
+    @Override
+    int graphNextNeighbor(HnswGraph graph) {
+      if (++upto < cur.size()) {
+        return cur.node[upto];
+      }
+      return NO_MORE_DOCS;
+    }
+  }
 }
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java
index 63aebc40dd7..4cbcde3f362 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java
@@ -32,6 +32,12 @@ import java.util.List;
 import java.util.Map;
 import java.util.Random;
 import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 import java.util.stream.Collectors;
 import org.apache.lucene.codecs.KnnVectorsFormat;
 import org.apache.lucene.codecs.lucene95.Lucene95Codec;
@@ -66,6 +72,7 @@ import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.BitSet;
 import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.FixedBitSet;
+import org.apache.lucene.util.NamedThreadFactory;
 import org.apache.lucene.util.RamUsageEstimator;
 import org.apache.lucene.util.VectorUtil;
 import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
@@ -990,6 +997,105 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
     assertTrue("overlap=" + overlap, overlap > 0.9);
   }
 
+  /**
+   * Tests that {@link OnHeapHnswGraph} can be searched from several threads at once: each query
+   * is first searched on the calling thread to record the expected top-K nodes, then searched
+   * again on a thread pool, and both runs must return exactly the same nodes.
+   */
+  public void testOnHeapHnswGraphSearch()
+      throws IOException, ExecutionException, InterruptedException, TimeoutException {
+    int size = atLeast(100);
+    int dim = atLeast(10);
+    AbstractMockVectorValues<T> vectors = vectorValues(size, dim);
+    int topK = 5;
+    HnswGraphBuilder<T> builder =
+        HnswGraphBuilder.create(
+            vectors, getVectorEncoding(), similarityFunction, 10, 30, random().nextLong());
+    OnHeapHnswGraph hnsw = builder.build(vectors.copy());
+    // Either no filter at all (null) or a random subset of accepted ordinals.
+    Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);
+
+    // Expected results, computed single-threaded.
+    List<T> queries = new ArrayList<>();
+    List<NeighborQueue> expects = new ArrayList<>();
+    for (int i = 0; i < 100; i++) {
+      T query = randomVector(dim);
+      queries.add(query);
+      expects.add(searchOnHeapGraph(query, vectors, hnsw, acceptOrds, topK));
+    }
+
+    // Actual results, computed concurrently for the same queries.
+    ExecutorService exec =
+        Executors.newFixedThreadPool(4, new NamedThreadFactory("onHeapHnswSearch"));
+    List<NeighborQueue> actuals = new ArrayList<>();
+    try {
+      List<Future<NeighborQueue>> futures = new ArrayList<>();
+      for (T query : queries) {
+        // RandomAccessVectorValues may hand out a reused buffer from vectorValue, so every task
+        // searches with its own copy instead of sharing one instance across threads.
+        futures.add(
+            exec.submit(() -> searchOnHeapGraph(query, vectors.copy(), hnsw, acceptOrds, topK)));
+      }
+      for (Future<NeighborQueue> future : futures) {
+        actuals.add(future.get(10, TimeUnit.SECONDS));
+      }
+    } finally {
+      // Stop the pool even if a search failed or timed out, so no worker threads leak.
+      exec.shutdownNow();
+    }
+
+    for (int i = 0; i < expects.size(); i++) {
+      NeighborQueue expect = expects.get(i);
+      NeighborQueue actual = actuals.get(i);
+      assertArrayEquals(expect.nodes(), actual.nodes());
+    }
+  }
+
+  /**
+   * Searches {@code hnsw} for {@code query} through the {@link OnHeapHnswGraph} overloads of
+   * {@code HnswGraphSearcher.search} (topK of 100, no visited limit), then pops entries from the
+   * returned queue until at most {@code topK} remain.
+   *
+   * <p>The casts are unchecked: they assume {@code T} is {@code byte[]} for BYTE encoding and
+   * {@code float[]} for FLOAT32, matching {@code getVectorEncoding()}.
+   */
+  @SuppressWarnings("unchecked")
+  private NeighborQueue searchOnHeapGraph(
+      T query,
+      RandomAccessVectorValues<T> vectors,
+      OnHeapHnswGraph hnsw,
+      Bits acceptOrds,
+      int topK)
+      throws IOException {
+    NeighborQueue result =
+        switch (getVectorEncoding()) {
+          case BYTE -> HnswGraphSearcher.search(
+              (byte[]) query,
+              100,
+              (RandomAccessVectorValues<byte[]>) vectors,
+              getVectorEncoding(),
+              similarityFunction,
+              hnsw,
+              acceptOrds,
+              Integer.MAX_VALUE);
+          case FLOAT32 -> HnswGraphSearcher.search(
+              (float[]) query,
+              100,
+              (RandomAccessVectorValues<float[]>) vectors,
+              getVectorEncoding(),
+              similarityFunction,
+              hnsw,
+              acceptOrds,
+              Integer.MAX_VALUE);
+        };
+
+    // Trim to topK entries; only those are compared by the test.
+    while (result.size() > topK) {
+      result.pop();
+    }
+    return result;
+  }
+
   private int computeOverlap(int[] a, int[] b) {
     Arrays.sort(a);
     Arrays.sort(b);