Posted to issues@lucene.apache.org by GitBox <gi...@apache.org> on 2021/09/01 03:25:22 UTC

[GitHub] [lucene] msokolov commented on a change in pull request #267: LUCENE-10054 Handle hierarchy in graph construction and search

msokolov commented on a change in pull request #267:
URL: https://github.com/apache/lucene/pull/267#discussion_r699806490



##########
File path: lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
##########
@@ -107,32 +113,82 @@ public static NeighborQueue search(
       Random random)
       throws IOException {
     int size = graphValues.size();
+    int boundedNumSeed = Math.min(numSeed, 2 * size);
+    NeighborQueue results;
+
+    if (graphValues.maxLevel() == 0) {
+      // search in SNW; generate a number of entry points randomly
+      final int[] eps = new int[boundedNumSeed];
+      for (int i = 0; i < boundedNumSeed; i++) {
+        eps[i] = random.nextInt(size);
+      }
+      return searchLevel(query, topK, 0, eps, vectors, similarityFunction, graphValues, acceptOrds);
+    } else {
+      // search in hierarchical SNW

Review comment:
       I notice you use `SNW` throughout, but `HNSW` elsewhere. Should we refer to these as `NSW` (navigable small-world) graphs?

##########
File path: lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
##########
@@ -107,32 +113,82 @@ public static NeighborQueue search(
       Random random)
       throws IOException {
     int size = graphValues.size();
+    int boundedNumSeed = Math.min(numSeed, 2 * size);
+    NeighborQueue results;
+
+    if (graphValues.maxLevel() == 0) {
+      // search in SNW; generate a number of entry points randomly
+      final int[] eps = new int[boundedNumSeed];
+      for (int i = 0; i < boundedNumSeed; i++) {
+        eps[i] = random.nextInt(size);
+      }
+      return searchLevel(query, topK, 0, eps, vectors, similarityFunction, graphValues, acceptOrds);
+    } else {
+      // search in hierarchical SNW
+      int[] eps = new int[] {graphValues.entryNode()};
+      for (int level = graphValues.maxLevel(); level >= 1; level--) {
+        results =
+            HnswGraph.searchLevel(
+                query, 1, level, eps, vectors, similarityFunction, graphValues, acceptOrds);
+        eps = new int[] {results.pop()};
+      }
+      results =
+          HnswGraph.searchLevel(
+              query, boundedNumSeed, 0, eps, vectors, similarityFunction, graphValues, acceptOrds);
+      while (results.size() > topK) {
+        results.pop();
+      }
+      return results;
+    }
+  }
 
+  /**
+   * Searches for the nearest neighbors of a query vector in a given level
+   *
+   * @param query search query vector
+   * @param topK the number of nearest to query results to return

Review comment:
       Currently `topK` is always `== eps.length`; I wonder whether we need a `topK` parameter for `searchLevel` at all?

##########
File path: lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
##########
@@ -107,32 +113,82 @@ public static NeighborQueue search(
       Random random)
       throws IOException {
     int size = graphValues.size();
+    int boundedNumSeed = Math.min(numSeed, 2 * size);
+    NeighborQueue results;
+
+    if (graphValues.maxLevel() == 0) {
+      // search in SNW; generate a number of entry points randomly
+      final int[] eps = new int[boundedNumSeed];
+      for (int i = 0; i < boundedNumSeed; i++) {
+        eps[i] = random.nextInt(size);

Review comment:
       We don't want repeats here, I think? At least, we don't allow them in the current NSW impl.
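One way to avoid repeated entry points is to draw distinct ordinals, capping the count at `size`, since no more than `size` distinct ordinals exist (the `2 * size` bound in `boundedNumSeed` could otherwise ask for more than that). This is a minimal sketch, assuming plain `java.util.BitSet` rejection sampling; `EntryPoints` and `randomDistinctEntryPoints` are hypothetical names, not Lucene API.

```java
import java.util.BitSet;
import java.util.Random;

final class EntryPoints {
  /** Returns min(numSeed, size) distinct random ordinals in [0, size). Hypothetical helper. */
  static int[] randomDistinctEntryPoints(int numSeed, int size, Random random) {
    // At most `size` distinct ordinals exist, so cap the request there.
    int count = Math.min(numSeed, size);
    int[] eps = new int[count];
    BitSet chosen = new BitSet(size);
    int filled = 0;
    while (filled < count) {
      int candidate = random.nextInt(size);
      // Rejection sampling: skip ordinals that were already picked.
      if (!chosen.get(candidate)) {
        chosen.set(candidate);
        eps[filled++] = candidate;
      }
    }
    return eps;
  }
}
```

Rejection sampling keeps memory proportional to a bitset over `size`; when `count` approaches `size` it needs more draws, which is acceptable for small seed counts.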

##########
File path: lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
##########
@@ -146,20 +168,72 @@ void addGraphNode(int node, float[] value) throws IOException {
      * nearest neighbors that are closer to the new node than they are to the previously-selected
      * neighbors
      */
-    addDiverseNeighbors(node, candidates);
+    addDiverseNeighbors(0, node, candidates);
+  }
+
+  // build hierarchical navigable small world graph (multi-layered)
+  void buildHNSW(RandomAccessVectorValues vectors) throws IOException {
+    long start = System.nanoTime(), t = start;
+    // start at node 1! node 0 is added implicitly, in the constructor
+    for (int node = 1; node < vectors.size(); node++) {
+      addGraphNodeHNSW(node, vectors.vectorValue(node));
+      if (node % 10000 == 0) {

Review comment:
       Can we refactor this and share it with the other place where we do the same thing?
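The body of that `if` is cut off in the quoted hunk. Assuming both build loops print elapsed-time progress every 10000 nodes (tracked with timestamps like the `start`/`t` pair above), a shared helper could look like the following sketch. `BuildProgress`, `maybeLog`, and the message format are hypothetical, and the logger is passed in as a `Consumer<String>` rather than tied to any particular Lucene logging API.

```java
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

final class BuildProgress {
  static final int LOG_INTERVAL = 10000;

  private final long start = System.nanoTime(); // when the build started
  private long last = start; // when the previous message was emitted
  private final Consumer<String> log;

  BuildProgress(Consumer<String> log) {
    this.log = log;
  }

  /** Emits a message every LOG_INTERVAL nodes; returns true if one was emitted. */
  boolean maybeLog(int node) {
    if (node % LOG_INTERVAL != 0) {
      return false;
    }
    long now = System.nanoTime();
    log.accept(
        String.format(
            "built %d nodes in %d/%d ms",
            node,
            TimeUnit.NANOSECONDS.toMillis(now - last),
            TimeUnit.NANOSECONDS.toMillis(now - start)));
    last = now;
    return true;
  }
}
```

Each loop would then create one `BuildProgress` and call `maybeLog(node)` per iteration instead of duplicating the modulo check and timing arithmetic.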

##########
File path: lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
##########
@@ -188,15 +244,28 @@ public int size() {
   }
 
   // TODO: optimize RAM usage so not to store references for all nodes for levels > 0
+  // TODO: add extra levels if level >= numLevels
   public void addNode(int level, int node) {
     if (level > 0) {
+      // if the new node introduces a new level, make this node the graph's new entry point
+      if (level > curMaxLevel) {
+        curMaxLevel = level;
+        entryNode = node;
+        // add more levels if needed
+        if (level >= graph.size()) {

Review comment:
       Wait, what does `graph.size()` mean here? Is it the number of nodes in level 0, or the number of levels? Oh, I remember: `graph` is simply a list of levels. I wonder if it would be clearer to call it `graphLevels`?
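To illustrate how the rename reads, here is a simplified, hypothetical stand-in for the level bookkeeping in `addNode`. It only records which ordinals sit on each level (the real graph stores neighbor lists), and `LeveledGraph` and `numLevels` are illustrative names.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in: each level records the ordinals it contains. */
final class LeveledGraph {
  // graphLevels.get(l) lists the nodes present on level l, so graphLevels.size() is the level count.
  final List<List<Integer>> graphLevels = new ArrayList<>();
  int curMaxLevel = 0;
  int entryNode = 0;

  LeveledGraph() {
    graphLevels.add(new ArrayList<>()); // level 0 always exists
  }

  void addNode(int level, int node) {
    if (level > curMaxLevel) {
      // A node that opens a new top level becomes the search entry point.
      curMaxLevel = level;
      entryNode = node;
    }
    // Grow the list of levels until index `level` is valid.
    while (graphLevels.size() <= level) {
      graphLevels.add(new ArrayList<>());
    }
    graphLevels.get(level).add(node);
  }

  int numLevels() {
    return graphLevels.size();
  }
}
```

With the name `graphLevels`, a check like `level >= graphLevels.size()` reads unambiguously as "this level does not exist yet".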

##########
File path: lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
##########
@@ -107,32 +113,82 @@ public static NeighborQueue search(
       Random random)
       throws IOException {
     int size = graphValues.size();
+    int boundedNumSeed = Math.min(numSeed, 2 * size);
+    NeighborQueue results;
+
+    if (graphValues.maxLevel() == 0) {
+      // search in SNW; generate a number of entry points randomly
+      final int[] eps = new int[boundedNumSeed];
+      for (int i = 0; i < boundedNumSeed; i++) {
+        eps[i] = random.nextInt(size);
+      }
+      return searchLevel(query, topK, 0, eps, vectors, similarityFunction, graphValues, acceptOrds);
+    } else {
+      // search in hierarchical SNW
+      int[] eps = new int[] {graphValues.entryNode()};
+      for (int level = graphValues.maxLevel(); level >= 1; level--) {
+        results =
+            HnswGraph.searchLevel(
+                query, 1, level, eps, vectors, similarityFunction, graphValues, acceptOrds);
+        eps = new int[] {results.pop()};
+      }
+      results =
+          HnswGraph.searchLevel(
+              query, boundedNumSeed, 0, eps, vectors, similarityFunction, graphValues, acceptOrds);
+      while (results.size() > topK) {
+        results.pop();
+      }
+      return results;
+    }
+  }
 
+  /**
+   * Searches for the nearest neighbors of a query vector in a given level
+   *
+   * @param query search query vector
+   * @param topK the number of nearest to query results to return
+   * @param level level to search
+   * @param eps the entry points for search at this level
+   * @param vectors vector values
+   * @param similarityFunction similarity function
+   * @param graphValues the graph values
+   * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
+   *     {@code null} if they are all allowed to match.
+   * @return a priority queue holding the closest neighbors found
+   */
+  static NeighborQueue searchLevel(
+      float[] query,
+      int topK,
+      int level,
+      final int[] eps,
+      RandomAccessVectorValues vectors,
+      VectorSimilarityFunction similarityFunction,
+      KnnGraphValues graphValues,
+      Bits acceptOrds)
+      throws IOException {
+
+    int size = graphValues.size();
+    int queueSize = Math.max(eps.length, topK);
     // MIN heap, holding the top results
-    NeighborQueue results = new NeighborQueue(numSeed, similarityFunction.reversed);
+    NeighborQueue results = new NeighborQueue(queueSize, similarityFunction.reversed);
     // MAX heap, from which to pull the candidate nodes
-    NeighborQueue candidates = new NeighborQueue(numSeed, !similarityFunction.reversed);
-
+    NeighborQueue candidates = new NeighborQueue(queueSize, !similarityFunction.reversed);
     // set of ordinals that have been visited by search on this layer, used to avoid backtracking
     SparseFixedBitSet visited = new SparseFixedBitSet(size);
-    // get initial candidates at random
-    int boundedNumSeed = Math.min(numSeed, 2 * size);
-    for (int i = 0; i < boundedNumSeed; i++) {
-      int entryPoint = random.nextInt(size);
-      if (visited.get(entryPoint) == false) {
-        visited.set(entryPoint);
-        // explore the topK starting points of some random numSeed probes
-        float score = similarityFunction.compare(query, vectors.vectorValue(entryPoint));
-        candidates.add(entryPoint, score);
-        if (acceptOrds == null || acceptOrds.get(entryPoint)) {
-          results.add(entryPoint, score);
+
+    for (int i = 0; i < eps.length; i++) {

Review comment:
       Might be cleaner with `for (int ep : eps)`.
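For reference, here is a simplified, self-contained version of single-level best-first search that uses the for-each form when seeding from the entry points. It is a sketch, not the PR's `searchLevel`. It assumes dot-product similarity where higher is better, plain `int[][]` adjacency, and `java.util.PriorityQueue` in place of `NeighborQueue`, and it omits `acceptOrds` filtering. `LayerSearch` is a hypothetical name.

```java
import java.util.BitSet;
import java.util.PriorityQueue;

final class LayerSearch {
  // Dot product; this sketch assumes higher scores mean more similar.
  static float score(float[] a, float[] b) {
    float s = 0;
    for (int i = 0; i < a.length; i++) {
      s += a[i] * b[i];
    }
    return s;
  }

  /** Best-first search on one level; neighbors[n] lists n's neighbors. Returns best first. */
  static int[] searchLevel(
      float[] query, int queueSize, int[] eps, float[][] vectors, int[][] neighbors) {
    float[] scores = new float[vectors.length]; // filled in as nodes are visited
    // MIN heap of kept results (worst on top); MAX heap of candidates (best on top).
    PriorityQueue<Integer> results =
        new PriorityQueue<>((a, b) -> Float.compare(scores[a], scores[b]));
    PriorityQueue<Integer> candidates =
        new PriorityQueue<>((a, b) -> Float.compare(scores[b], scores[a]));
    BitSet visited = new BitSet(vectors.length);

    // Seed both heaps from the entry points, skipping duplicates.
    for (int ep : eps) {
      if (visited.get(ep)) {
        continue;
      }
      visited.set(ep);
      scores[ep] = score(query, vectors[ep]);
      candidates.add(ep);
      results.add(ep);
      if (results.size() > queueSize) {
        results.poll(); // drop the current worst result
      }
    }

    while (!candidates.isEmpty()) {
      int c = candidates.poll();
      // Stop once the best remaining candidate cannot improve the kept results.
      if (results.size() >= queueSize && scores[c] < scores[results.peek()]) {
        break;
      }
      for (int nbr : neighbors[c]) {
        if (visited.get(nbr)) {
          continue;
        }
        visited.set(nbr);
        scores[nbr] = score(query, vectors[nbr]);
        if (results.size() < queueSize || scores[nbr] > scores[results.peek()]) {
          candidates.add(nbr);
          results.add(nbr);
          if (results.size() > queueSize) {
            results.poll();
          }
        }
      }
    }

    // The min heap yields the worst result first, so fill the output from the back.
    int[] out = new int[results.size()];
    for (int i = out.length - 1; i >= 0; i--) {
      out[i] = results.poll();
    }
    return out;
  }
}
```

Because the visited check also covers the seeding loop, a repeated entry point is simply skipped rather than scored twice.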

##########
File path: lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
##########
@@ -107,32 +113,82 @@ public static NeighborQueue search(
       Random random)
       throws IOException {
     int size = graphValues.size();
+    int boundedNumSeed = Math.min(numSeed, 2 * size);
+    NeighborQueue results;
+
+    if (graphValues.maxLevel() == 0) {
+      // search in SNW; generate a number of entry points randomly
+      final int[] eps = new int[boundedNumSeed];
+      for (int i = 0; i < boundedNumSeed; i++) {
+        eps[i] = random.nextInt(size);
+      }
+      return searchLevel(query, topK, 0, eps, vectors, similarityFunction, graphValues, acceptOrds);
+    } else {
+      // search in hierarchical SNW
+      int[] eps = new int[] {graphValues.entryNode()};
+      for (int level = graphValues.maxLevel(); level >= 1; level--) {
+        results =
+            HnswGraph.searchLevel(
+                query, 1, level, eps, vectors, similarityFunction, graphValues, acceptOrds);
+        eps = new int[] {results.pop()};

Review comment:
       Maybe update `eps[0]` in place to avoid the allocation?
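Here is a self-contained sketch of the top-down descent with one entry-point array reused across levels. It assumes dot-product similarity, per-level adjacency indexed by global ordinal (empty arrays for nodes absent from a level), and a plain greedy hill-climb standing in for `searchLevel(query, 1, level, ...)`. `HierarchicalDescent`, `greedyClosest`, and `descend` are hypothetical names.

```java
final class HierarchicalDescent {
  // Dot product; higher means more similar in this sketch.
  static float score(float[] a, float[] b) {
    float s = 0;
    for (int i = 0; i < a.length; i++) {
      s += a[i] * b[i];
    }
    return s;
  }

  /** Greedy hill-climb on one level: move to a better neighbor until none improves. */
  static int greedyClosest(float[] query, int start, float[][] vectors, int[][] levelNeighbors) {
    int best = start;
    float bestScore = score(query, vectors[best]);
    boolean improved = true;
    while (improved) {
      improved = false;
      // The array expression is evaluated once, so reassigning `best` mid-loop is safe.
      for (int nbr : levelNeighbors[best]) {
        float s = score(query, vectors[nbr]);
        if (s > bestScore) {
          best = nbr;
          bestScore = s;
          improved = true;
        }
      }
    }
    return best;
  }

  /** Descends from the top level to level 1; returns the entry points for level 0. */
  static int[] descend(float[] query, int entryNode, float[][] vectors, int[][][] graphLevels) {
    int[] eps = {entryNode};
    for (int level = graphLevels.length - 1; level >= 1; level--) {
      // Overwrite in place instead of allocating `new int[] {...}` on every level.
      eps[0] = greedyClosest(query, eps[0], vectors, graphLevels[level]);
    }
    return eps;
  }
}
```

The returned single-element array can be passed straight to the level-0 search, so the whole descent allocates exactly one array.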

##########
File path: lucene/core/src/test/org/apache/lucene/util/hnsw/TestHNSWGraph2.java
##########
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util.hnsw;
+
+import static org.apache.lucene.index.TestKnnGraph.assertMaxConn;
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Random;
+import java.util.Set;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.util.LuceneTestCase;
+import org.apache.lucene.util.VectorUtil;
+
+public class TestHNSWGraph2 extends LuceneTestCase {
+
+  // Tests that graph is consistent.
+  public void testGraphConsistent() throws IOException {
+    int dim = random().nextInt(100) + 1;
+    int nDoc = random().nextInt(100) + 1;
+    MockVectorValues values = new MockVectorValues(createRandomVectors(nDoc, dim, random()));
+    int beamWidth = random().nextInt(10) + 5;
+    int maxConn = random().nextInt(10) + 5;
+    double ml = 1 / Math.log(1.0 * maxConn);
+    long seed = random().nextLong();
+    VectorSimilarityFunction similarityFunction =
+        VectorSimilarityFunction.values()[
+            random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
+    HnswGraphBuilder builder =
+        new HnswGraphBuilder(values, similarityFunction, maxConn, beamWidth, seed, ml);
+    HnswGraph hnsw = builder.build(values);
+    assertConsistentGraph(hnsw, maxConn);
+  }
+
+  /**
+   * For each level of the graph, test that
+   *
+   * <p>1. There are no orphan nodes without any friends
+   *
+   * <p>2. If orphans are found, then the level must contain only 0 or a single node
+   *
+   * <p>3. If the number of nodes on the level doesn't exceed maxConn, assert that the graph is
+   * fully connected, i.e. any node is reachable from any other node.
+   *
+   * <p>4. If the number of nodes on the level exceeds maxConn, assert that maxConn is respected.
+   *
+   * <p>copied from TestKnnGraph::assertConsistentGraph with parts relevant only to in-memory graphs
+   * TODO: remove when hierarchical graph is implemented on disk
+   */
+  private static void assertConsistentGraph(HnswGraph hnsw, int maxConn) {
+    for (int level = hnsw.maxLevel(); level >= 0; level--) {
+      hnsw.seekLevel(level);
+
+      int[][] graph = new int[hnsw.size()][];
+      int nodesCount = 0;
+      boolean foundOrphan = false;
+
+      for (int node = hnsw.nextNodeOnLevel();
+          node != DocIdSetIterator.NO_MORE_DOCS;
+          node = hnsw.nextNodeOnLevel()) {
+        hnsw.seek(level, node);
+        int arc;
+        List<Integer> friends = new ArrayList<>();
+        while ((arc = hnsw.nextNeighbor()) != NO_MORE_DOCS) {
+          friends.add(arc);
+        }
+        if (friends.size() == 0) {
+          foundOrphan = true;
+        } else {
+          int[] friendsCopy = new int[friends.size()];
+          for (int f = 0; f < friends.size(); f++) {
+            friendsCopy[f] = friends.get(f);
+          }
+          graph[node] = friendsCopy;
+        }
+        nodesCount++;
+      }
+      // System.out.println("Level[" + level + "] has [" + nodesCount + "] nodes.");
+
+      assertFalse("No nodes on level [" + level + "]", nodesCount == 0);
+      if (nodesCount == 1) {
+        assertTrue(
+            "Graph with 1 node has unexpected neighbors on level [" + level + "]", foundOrphan);
+      } else {
+        assertFalse("Graph has orphan nodes with no friends on level [" + level + "]", foundOrphan);
+        if (maxConn > nodesCount) {
+          // assert that the graph in fully connected, i.e. any node can be reached from any other

Review comment:
       "is fully"

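The quoted hunk stops before the connectivity check that item 3 of the Javadoc describes. Because neighbor lists may be directed, "any node is reachable from any other node" requires a traversal from every node, not just one. Below is a minimal BFS sketch, assuming the `int[][] graph` layout the test builds, with `null` meaning a node is absent from the level. Note that the test above also leaves orphans `null`, but those are caught by the orphan check first. `Connectivity` and `isFullyConnected` are hypothetical names.

```java
import java.util.ArrayDeque;
import java.util.BitSet;

final class Connectivity {
  /** True if every node present on the level can reach every other present node. */
  static boolean isFullyConnected(int[][] graph) {
    int present = 0;
    for (int[] friends : graph) {
      if (friends != null) {
        present++;
      }
    }
    for (int start = 0; start < graph.length; start++) {
      if (graph[start] == null) {
        continue;
      }
      // BFS from `start`; edges may be one-way, so repeat from every node.
      BitSet seen = new BitSet(graph.length);
      ArrayDeque<Integer> queue = new ArrayDeque<>();
      seen.set(start);
      queue.add(start);
      int reached = 1;
      while (!queue.isEmpty()) {
        int node = queue.poll();
        for (int nbr : graph[node]) {
          if (graph[nbr] == null) {
            return false; // edge to a node that is not on this level
          }
          if (!seen.get(nbr)) {
            seen.set(nbr);
            reached++;
            queue.add(nbr);
          }
        }
      }
      if (reached != present) {
        return false;
      }
    }
    return true;
  }
}
```

Running BFS from every node costs O(n * (n + e)), which is fine here because the check only applies when a level has fewer than `maxConn` nodes.
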
##########
File path: lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
##########
@@ -267,4 +341,12 @@ private int findNonDiverse(NeighborArray neighbors) throws IOException {
     }
     return -1;
   }
+
+  private static int getRandomGraphLevel(double ml, Random random) {
+    float randFloat = random.nextFloat();
+    while (randFloat == 0.0f) {

Review comment:
       Looks like a do-while to me!
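A do-while version of the rejection loop might look like this. The quoted hunk is cut off before the method's return statement, so the return expression here is an assumption: it uses the usual HNSW level assignment, floor(-ln(U) * ml), with `ml` as in the test above (`1 / Math.log(maxConn)`). The wrapper class `GraphLevels` is hypothetical.

```java
import java.util.Random;

final class GraphLevels {
  /** Draws a node's top level as floor(-ln(U) * ml), with U uniform in (0, 1). */
  static int getRandomGraphLevel(double ml, Random random) {
    float randFloat;
    do {
      // nextFloat() is in [0, 1); reject 0 so that Math.log stays finite.
      randFloat = random.nextFloat();
    } while (randFloat == 0.0f);
    // -ln(U) > 0, so the cast truncates a positive value, i.e. takes its floor.
    return (int) (-Math.log(randFloat) * ml);
  }
}
```

With `ml = 1 / ln(M)`, a node reaches level 1 or higher with probability 1/M, so most nodes stay on level 0.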

##########
File path: lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
##########
@@ -188,15 +244,28 @@ public int size() {
   }
 
   // TODO: optimize RAM usage so not to store references for all nodes for levels > 0
+  // TODO: add extra levels if level >= numLevels
   public void addNode(int level, int node) {
     if (level > 0) {
+      // if the new node introduces a new level, make this node the graph's new entry point
+      if (level > curMaxLevel) {
+        curMaxLevel = level;
+        entryNode = node;
+        // add more levels if needed
+        if (level >= graph.size()) {
+          for (int i = graph.size(); i <= level; i++) {
+            graph.add(new ArrayList<>());
+          }
+        }
+      }
       // Levels above 0th don't contain all nodes,

Review comment:
       To be more minimalistic here, I think we need three pointers (ints): one within this level, one to the next level, and one to level 0. Not every level needs all of them; e.g. for levels 0 and 1, some of them are redundant. If we use the current data structure to store the within-level pointers and add two (possibly null) int[] arrays for the other pointers, I think it will be nice and compact. Fine to save it for later; I'm just thinking out loud about how this might end up looking.
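Here is one possible reading of that layout as a sketch. It is hypothetical, not what the PR implements. Nodes on each level get a dense level-local index. A level keeps its within-level neighbor arrays plus two possibly-null `int[]` maps: one to the node's index on the level below and one to its level-0 ordinal. On level 0 both maps are unnecessary. On level 1, "the level below" is level 0, so one map serves both roles. `CompactLevel` and its methods are illustrative names.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical compact per-level layout; nodes are addressed by level-local index. */
final class CompactLevel {
  final int level;
  // Within-level pointers: neighbors.get(i) holds level-local indices of node i's neighbors.
  final List<int[]> neighbors = new ArrayList<>();
  // Index of each node on level (level - 1); null on levels 0 and 1.
  final int[] toNextLevel;
  // Level-0 ordinal of each node; null on level 0, where local index == ordinal.
  final int[] toLevel0;

  CompactLevel(int level, int[] toNextLevel, int[] toLevel0) {
    this.level = level;
    this.toNextLevel = toNextLevel;
    this.toLevel0 = toLevel0;
  }

  int level0Ordinal(int localIndex) {
    return toLevel0 == null ? localIndex : toLevel0[localIndex];
  }

  int nextLevelIndex(int localIndex) {
    if (level == 0) {
      throw new IllegalStateException("level 0 has no level below it");
    }
    // On level 1 the level below is level 0, so the level-0 map serves both roles.
    return toNextLevel != null ? toNextLevel[localIndex] : level0Ordinal(localIndex);
  }
}
```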




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@lucene.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
For additional commands, e-mail: issues-help@lucene.apache.org