You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by ma...@apache.org on 2021/08/23 19:54:31 UTC

[lucene] branch main updated: LUCENE-10054 Make HnswGraph hierarchical (#250)

This is an automated email from the ASF dual-hosted git repository.

mayya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/main by this push:
     new 257d256  LUCENE-10054 Make HnswGraph hierarchical (#250)
257d256 is described below

commit 257d256defc47c446493ea99b841f58c543673c0
Author: Mayya Sharipova <ma...@elastic.co>
AuthorDate: Mon Aug 23 15:54:26 2021 -0400

    LUCENE-10054 Make HnswGraph hierarchical (#250)
    
    Currently HNSW has only a single layer.
    This is the first step toward making it multi-layered.
    
    To keep changes small, this PR only adds
    multiple layers to the HnswGraph class.
    
    TODO for following PRs:
    - modify the graph construction and search algorithms for a hierarchical
    graph.
    - modify Lucene90HnswVectorsWriter and Lucene90HnswVectorsReader to
    write and read multiple layers.
---
 .../codecs/lucene90/Lucene90HnswVectorsReader.java |  2 +-
 .../codecs/lucene90/Lucene90HnswVectorsWriter.java |  3 +-
 .../org/apache/lucene/index/KnnGraphValues.java    |  7 +--
 .../org/apache/lucene/util/hnsw/HnswGraph.java     | 60 ++++++++++++++--------
 .../apache/lucene/util/hnsw/HnswGraphBuilder.java  | 13 ++---
 .../test/org/apache/lucene/index/TestKnnGraph.java |  4 +-
 .../apache/lucene/util/hnsw/KnnGraphTester.java    |  6 +--
 .../org/apache/lucene/util/hnsw/TestHnswGraph.java | 18 +++----
 8 files changed, 67 insertions(+), 46 deletions(-)

diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
index 70e386d..726bc4c 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
@@ -481,7 +481,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
     }
 
     @Override
-    public void seek(int targetOrd) throws IOException {
+    public void seek(int level, int targetOrd) throws IOException {
       // unsafe; no bounds checking
       dataIn.seek(entry.ordOffsets[targetOrd]);
       arcCount = dataIn.readInt();
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsWriter.java
index 0c2832b..f82278a 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsWriter.java
@@ -208,11 +208,12 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
     hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
     HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
 
+    // TODO: implement storing of hierarchical graph; for now stores only 0th level
     for (int ord = 0; ord < count; ord++) {
       // write graph
       offsets[ord] = graphData.getFilePointer() - graphDataOffset;
 
-      NeighborArray neighbors = graph.getNeighbors(ord);
+      NeighborArray neighbors = graph.getNeighbors(0, ord);
       int size = neighbors.size();
 
       // Destructively modify; it's ok we are discarding it after this
diff --git a/lucene/core/src/java/org/apache/lucene/index/KnnGraphValues.java b/lucene/core/src/java/org/apache/lucene/index/KnnGraphValues.java
index f8f175a..4ff1e86 100644
--- a/lucene/core/src/java/org/apache/lucene/index/KnnGraphValues.java
+++ b/lucene/core/src/java/org/apache/lucene/index/KnnGraphValues.java
@@ -35,17 +35,18 @@ public abstract class KnnGraphValues {
    * Move the pointer to exactly {@code target}, the id of a node in the graph. After this method
    * returns, call {@link #nextNeighbor()} to return successive (ordered) connected node ordinals.
    *
+   * @param level level of the graph
    * @param target must be a valid node in the graph, ie. &ge; 0 and &lt; {@link
    *     VectorValues#size()}.
    */
-  public abstract void seek(int target) throws IOException;
+  public abstract void seek(int level, int target) throws IOException;
 
   /** Returns the number of nodes in the graph */
   public abstract int size();
 
   /**
    * Iterates over the neighbor list. It is illegal to call this method after it returns
-   * NO_MORE_DOCS without calling {@link #seek(int)}, which resets the iterator.
+   * NO_MORE_DOCS without calling {@link #seek(int, int)}, which resets the iterator.
    *
    * @return a node ordinal in the graph, or NO_MORE_DOCS if the iteration is complete.
    */
@@ -61,7 +62,7 @@ public abstract class KnnGraphValues {
         }
 
         @Override
-        public void seek(int target) {}
+        public void seek(int level, int target) {}
 
         @Override
         public int size() {
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
index d1f0420..8f6a8f0 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
@@ -40,10 +40,10 @@ import org.apache.lucene.util.SparseFixedBitSet;
  * <h2>Hyperparameters</h2>
  *
  * <ul>
- *   <li><code>numSeed</code> is the equivalent of <code>m</code> in the 2012 paper; it controls the
+ *   <li><code>numSeed</code> is the equivalent of <code>m</code> in the 2014 paper; it controls the
  *       number of random entry points to sample.
  *   <li><code>beamWidth</code> in {@link HnswGraphBuilder} has the same meaning as <code>efConst
- *       </code> in the 2016 paper. It is the number of nearest neighbor candidates to track while
+ *       </code> in the 2018 paper. It is the number of nearest neighbor candidates to track while
  *       searching the graph for each newly inserted node.
  *   <li><code>maxConn</code> has the same meaning as <code>M</code> in the later paper; it controls
  *       how many of the <code>efConst</code> neighbors are connected to the new node
@@ -56,22 +56,28 @@ import org.apache.lucene.util.SparseFixedBitSet;
 public final class HnswGraph extends KnnGraphValues {
 
   private final int maxConn;
-
-  // Each entry lists the top maxConn neighbors of a node. The nodes correspond to vectors added to
-  // HnswBuilder, and the
-  // node values are the ordinals of those vectors.
-  private final List<NeighborArray> graph;
+  // graph is a list of graph levels.
+  // Each level is represented as List<NeighborArray> – nodes' connections on this level.
+  // Each entry in the list has the top maxConn neighbors of a node. The nodes correspond to vectors
+  // added to HnswBuilder, and the node values are the ordinals of those vectors.
+  private final List<List<NeighborArray>> graph;
 
   // KnnGraphValues iterator members
   private int upto;
   private NeighborArray cur;
 
-  HnswGraph(int maxConn) {
-    graph = new ArrayList<>();
-    // Typically with diversity criteria we see nodes not fully occupied; average fanout seems to be
-    // about 1/2 maxConn. There is some indexing time penalty for under-allocating, but saves RAM
-    graph.add(new NeighborArray(Math.max(32, maxConn / 4)));
+  HnswGraph(int maxConn, int numLevels, int levelOfFirstNode) {
     this.maxConn = maxConn;
+    this.graph = new ArrayList<>(numLevels);
+    for (int i = 0; i < numLevels; i++) {
+      graph.add(new ArrayList<>());
+    }
+    for (int i = 0; i <= levelOfFirstNode; i++) {
+      // Typically with diversity criteria we see nodes not fully occupied;
+      // average fanout seems to be about 1/2 maxConn.
+      // There is some indexing time penalty for under-allocating, but saves RAM
+      graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4)));
+    }
   }
 
   /**
@@ -89,6 +95,7 @@ public final class HnswGraph extends KnnGraphValues {
    * @param random a source of randomness, used for generating entry points to the graph
    * @return a priority queue holding the closest neighbors found
    */
+  // TODO: implement hierarchical search, currently searches only 0th level
   public static NeighborQueue search(
       float[] query,
       int topK,
@@ -137,7 +144,7 @@ public final class HnswGraph extends KnnGraphValues {
         }
       }
       int topCandidateNode = candidates.pop();
-      graphValues.seek(topCandidateNode);
+      graphValues.seek(0, topCandidateNode);
       int friendOrd;
       while ((friendOrd = graphValues.nextNeighbor()) != NO_MORE_DOCS) {
         assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
@@ -166,25 +173,36 @@ public final class HnswGraph extends KnnGraphValues {
   /**
    * Returns the {@link NeighborQueue} connected to the given node.
    *
+   * @param level level of the graph
    * @param node the node whose neighbors are returned
    */
-  public NeighborArray getNeighbors(int node) {
-    return graph.get(node);
+  public NeighborArray getNeighbors(int level, int node) {
+    NeighborArray result = graph.get(level).get(node);
+    assert result != null;
+    return result;
   }
 
   @Override
   public int size() {
-    return graph.size();
+    return graph.get(0).size(); // all nodes are located on the 0th level
   }
 
-  int addNode() {
-    graph.add(new NeighborArray(maxConn + 1));
-    return graph.size() - 1;
+  // TODO: optimize RAM usage so not to store references for all nodes for levels > 0
+  public void addNode(int level, int node) {
+    if (level > 0) {
+      // Levels above 0th don't contain all nodes,
+      // so for missing nodes we add null NeighborArray
+      int nullsToAdd = node - graph.get(level).size();
+      for (int i = 0; i < nullsToAdd; i++) {
+        graph.get(level).add(null);
+      }
+    }
+    graph.get(level).add(new NeighborArray(maxConn + 1));
   }
 
   @Override
-  public void seek(int targetNode) {
-    cur = getNeighbors(targetNode);
+  public void seek(int level, int targetNode) {
+    cur = getNeighbors(level, targetNode);
     upto = -1;
   }
 
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index d12a731..f7362e3 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -84,7 +84,7 @@ public final class HnswGraphBuilder {
     }
     this.maxConn = maxConn;
     this.beamWidth = beamWidth;
-    this.hnsw = new HnswGraph(maxConn);
+    this.hnsw = new HnswGraph(maxConn, 1, 0);
     bound = BoundsChecker.create(similarityFunction.reversed);
     random = new Random(seed);
     scratch = new NeighborArray(Math.max(beamWidth, maxConn + 1));
@@ -109,7 +109,7 @@ public final class HnswGraphBuilder {
     long start = System.nanoTime(), t = start;
     // start at node 1! node 0 is added implicitly, in the constructor
     for (int node = 1; node < vectors.size(); node++) {
-      addGraphNode(vectors.vectorValue(node));
+      addGraphNode(node, vectors.vectorValue(node));
       if (node % 10000 == 0) {
         if (infoStream.isEnabled(HNSW_COMPONENT)) {
           long now = System.nanoTime();
@@ -133,13 +133,14 @@ public final class HnswGraphBuilder {
   }
 
   /** Inserts a doc with vector value to the graph */
-  void addGraphNode(float[] value) throws IOException {
+  // TODO: implement hierarchical graph building
+  void addGraphNode(int node, float[] value) throws IOException {
     // We pass 'null' for acceptOrds because there are no deletions while building the graph
     NeighborQueue candidates =
         HnswGraph.search(
             value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, null, random);
 
-    int node = hnsw.addNode();
+    hnsw.addNode(0, node);
 
     /* connect neighbors to the new node, using a diversity heuristic that chooses successive
      * nearest neighbors that are closer to the new node than they are to the previously-selected
@@ -158,7 +159,7 @@ public final class HnswGraphBuilder {
      * is closer to target than it is to any of the already-selected neighbors (ie selected in this method,
      * since the node is new and has no prior neighbors).
      */
-    NeighborArray neighbors = hnsw.getNeighbors(node);
+    NeighborArray neighbors = hnsw.getNeighbors(0, node);
     assert neighbors.size() == 0; // new node
     popToScratch(candidates);
     selectDiverse(neighbors, scratch);
@@ -168,7 +169,7 @@ public final class HnswGraphBuilder {
     int size = neighbors.size();
     for (int i = 0; i < size; i++) {
       int nbr = neighbors.node[i];
-      NeighborArray nbrNbr = hnsw.getNeighbors(nbr);
+      NeighborArray nbrNbr = hnsw.getNeighbors(0, nbr);
       nbrNbr.add(node, neighbors.score[i]);
       if (nbrNbr.size() > maxConn) {
         diversityUpdate(nbrNbr);
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
index b035a2f..0c790d7 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
@@ -214,7 +214,7 @@ public class TestKnnGraph extends LuceneTestCase {
     int[] scratch = new int[maxConn];
     for (int node = 0; node < size; node++) {
       int n, count = 0;
-      values.seek(node);
+      values.seek(0, node);
       while ((n = values.nextNeighbor()) != NO_MORE_DOCS) {
         scratch[count++] = n;
         // graph[node][i++] = n;
@@ -352,7 +352,7 @@ public class TestKnnGraph extends LuceneTestCase {
             break;
           }
           int id = Integer.parseInt(reader.document(i).get("id"));
-          graphValues.seek(graphSize);
+          graphValues.seek(0, graphSize);
           // documents with KnnGraphValues have the expected vectors
           float[] scratch = vectorValues.vectorValue();
           assertArrayEquals(
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
index fcdf0aa..cc977c5 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java
@@ -256,7 +256,7 @@ public class KnnGraphTester {
           new HnswGraphBuilder(vectors, SIMILARITY_FUNCTION, maxConn, beamWidth, 0);
       // start at node 1
       for (int i = 1; i < numDocs; i++) {
-        builder.addGraphNode(values.vectorValue(i));
+        builder.addGraphNode(i, values.vectorValue(i));
         System.out.println("\nITERATION " + i);
         dumpGraph(builder.hnsw);
       }
@@ -265,7 +265,7 @@ public class KnnGraphTester {
 
   private void dumpGraph(HnswGraph hnsw) {
     for (int i = 0; i < hnsw.size(); i++) {
-      NeighborArray neighbors = hnsw.getNeighbors(i);
+      NeighborArray neighbors = hnsw.getNeighbors(0, i);
       System.out.printf(Locale.ROOT, "%5d", i);
       NeighborArray sorted = new NeighborArray(neighbors.size());
       for (int j = 0; j < neighbors.size(); j++) {
@@ -297,7 +297,7 @@ public class KnnGraphTester {
     int count = 0;
     int[] leafHist = new int[numDocs];
     for (int node = 0; node < numDocs; node++) {
-      knnValues.seek(node);
+      knnValues.seek(0, node);
       int n = 0;
       while (knnValues.nextNeighbor() != NO_MORE_DOCS) {
         ++n;
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
index bec0541..5fdd850 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
@@ -150,7 +150,7 @@ public class TestHnswGraph extends LuceneTestCase {
     // 45
     assertTrue("sum(result docs)=" + sum, sum < 75);
     for (int i = 0; i < nDoc; i++) {
-      NeighborArray neighbors = hnsw.getNeighbors(i);
+      NeighborArray neighbors = hnsw.getNeighbors(0, i);
       int[] nodes = neighbors.node;
       for (int j = 0; j < neighbors.size(); j++) {
         // all neighbors should be valid node ids.
@@ -252,15 +252,15 @@ public class TestHnswGraph extends LuceneTestCase {
             vectors, VectorSimilarityFunction.DOT_PRODUCT, 2, 10, random().nextInt());
     // node 0 is added by the builder constructor
     // builder.addGraphNode(vectors.vectorValue(0));
-    builder.addGraphNode(vectors.vectorValue(1));
-    builder.addGraphNode(vectors.vectorValue(2));
+    builder.addGraphNode(1, vectors.vectorValue(1));
+    builder.addGraphNode(2, vectors.vectorValue(2));
     // now every node has tried to attach every other node as a neighbor, but
     // some were excluded based on diversity check.
     assertNeighbors(builder.hnsw, 0, 1, 2);
     assertNeighbors(builder.hnsw, 1, 0);
     assertNeighbors(builder.hnsw, 2, 0);
 
-    builder.addGraphNode(vectors.vectorValue(3));
+    builder.addGraphNode(3, vectors.vectorValue(3));
     assertNeighbors(builder.hnsw, 0, 1, 2);
     // we added 3 here
     assertNeighbors(builder.hnsw, 1, 0, 3);
@@ -268,7 +268,7 @@ public class TestHnswGraph extends LuceneTestCase {
     assertNeighbors(builder.hnsw, 3, 1);
 
     // supplant an existing neighbor
-    builder.addGraphNode(vectors.vectorValue(4));
+    builder.addGraphNode(4, vectors.vectorValue(4));
     // 4 is the same distance from 0 that 2 is; we leave the existing node in place
     assertNeighbors(builder.hnsw, 0, 1, 2);
     // 4 is closer to 1 than either existing neighbor (0, 3). 3 fails diversity check with 4, so
@@ -279,7 +279,7 @@ public class TestHnswGraph extends LuceneTestCase {
     assertNeighbors(builder.hnsw, 3, 1, 4);
     assertNeighbors(builder.hnsw, 4, 1, 3);
 
-    builder.addGraphNode(vectors.vectorValue(5));
+    builder.addGraphNode(5, vectors.vectorValue(5));
     assertNeighbors(builder.hnsw, 0, 1, 2);
     assertNeighbors(builder.hnsw, 1, 0, 5);
     assertNeighbors(builder.hnsw, 2, 0);
@@ -291,7 +291,7 @@ public class TestHnswGraph extends LuceneTestCase {
 
   private void assertNeighbors(HnswGraph graph, int node, int... expected) {
     Arrays.sort(expected);
-    NeighborArray nn = graph.getNeighbors(node);
+    NeighborArray nn = graph.getNeighbors(0, node);
     int[] actual = ArrayUtil.copyOfSubArray(nn.node, 0, nn.size());
     Arrays.sort(actual);
     assertArrayEquals(
@@ -439,8 +439,8 @@ public class TestHnswGraph extends LuceneTestCase {
 
   private void assertGraphEqual(KnnGraphValues g, KnnGraphValues h, int size) throws IOException {
     for (int node = 0; node < size; node++) {
-      g.seek(node);
-      h.seek(node);
+      g.seek(0, node);
+      h.seek(0, node);
       assertEquals("arcs differ for node " + node, getNeighborNodes(g), getNeighborNodes(h));
     }
   }