You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by ma...@apache.org on 2022/09/08 21:04:06 UTC

[lucene] branch branch_9_4 updated: LUCENE-10592 Better estimate memory for HNSW graph (#11743)

This is an automated email from the ASF dual-hosted git repository.

mayya pushed a commit to branch branch_9_4
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/branch_9_4 by this push:
     new 3e4c892de89 LUCENE-10592 Better estimate memory for HNSW graph (#11743)
3e4c892de89 is described below

commit 3e4c892de89165ac0db4a0a28ac8aa12b17ce5d3
Author: Mayya Sharipova <ma...@elastic.co>
AuthorDate: Thu Sep 8 16:54:29 2022 -0400

    LUCENE-10592 Better estimate memory for HNSW graph (#11743)
    
    Improve the memory usage estimate for OnHeapHnswGraph
    and add tests for it.
    
    Also don't overallocate arrays in NeighborArray
    
    Relates to #992
---
 .../org/apache/lucene/util/hnsw/NeighborArray.java | 12 +++----
 .../apache/lucene/util/hnsw/OnHeapHnswGraph.java   | 20 +++++++----
 .../org/apache/lucene/util/hnsw/TestHnswGraph.java | 39 ++++++++++++++++------
 3 files changed, 48 insertions(+), 23 deletions(-)

diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
index 78224ed2358..ec1b5ec3e89 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
@@ -46,8 +46,8 @@ public class NeighborArray {
    * nodes.
    */
   public void add(int newNode, float newScore) {
-    if (size == node.length - 1) {
-      node = ArrayUtil.grow(node, (size + 1) * 3 / 2);
+    if (size == node.length) {
+      node = ArrayUtil.grow(node);
       score = ArrayUtil.growExact(score, node.length);
     }
     if (size > 0) {
@@ -63,8 +63,8 @@ public class NeighborArray {
 
   /** Add a new node to the NeighborArray into a correct sort position according to its score. */
   public void insertSorted(int newNode, float newScore) {
-    if (size == node.length - 1) {
-      node = ArrayUtil.grow(node, (size + 1) * 3 / 2);
+    if (size == node.length) {
+      node = ArrayUtil.grow(node);
       score = ArrayUtil.growExact(score, node.length);
     }
     int insertionPoint =
@@ -104,8 +104,8 @@ public class NeighborArray {
   }
 
   public void removeIndex(int idx) {
-    System.arraycopy(node, idx + 1, node, idx, size - idx);
-    System.arraycopy(score, idx + 1, score, idx, size - idx);
+    System.arraycopy(node, idx + 1, node, idx, size - idx - 1);
+    System.arraycopy(score, idx + 1, score, idx, size - idx - 1);
     size--;
   }
 
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
index 8cf6b54654c..78137c2a630 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
@@ -175,20 +175,28 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
     long neighborArrayBytes0 =
         nsize0 * (Integer.BYTES + Float.BYTES)
             + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2
-            + RamUsageEstimator.NUM_BYTES_OBJECT_REF;
+            + RamUsageEstimator.NUM_BYTES_OBJECT_REF
+            + Integer.BYTES * 2;
     long neighborArrayBytes =
         nsize * (Integer.BYTES + Float.BYTES)
             + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2
-            + RamUsageEstimator.NUM_BYTES_OBJECT_REF;
-
+            + RamUsageEstimator.NUM_BYTES_OBJECT_REF
+            + Integer.BYTES * 2;
     long total = 0;
     for (int l = 0; l < numLevels; l++) {
       int numNodesOnLevel = graph.get(l).size();
       if (l == 0) {
-        total += numNodesOnLevel * neighborArrayBytes0; // for graph;
+        total +=
+            numNodesOnLevel * neighborArrayBytes0
+                + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph;
       } else {
-        total += numNodesOnLevel * Integer.BYTES; // for nodesByLevel
-        total += numNodesOnLevel * neighborArrayBytes; // for graph;
+        total +=
+            nodesByLevel.get(l).length * Integer.BYTES
+                + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
+                + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for nodesByLevel
+        total +=
+            numNodesOnLevel * neighborArrayBytes
+                + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph;
       }
     }
     return total;
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
index d6fc3779afd..1c88a728c96 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java
@@ -17,9 +17,12 @@
 
 package org.apache.lucene.util.hnsw;
 
+import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween;
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.apache.lucene.tests.util.RamUsageTester.ramUsed;
 import static org.apache.lucene.util.VectorUtil.toBytesRef;
 
+import com.carrotsearch.randomizedtesting.RandomizedTest;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -59,6 +62,7 @@ import org.apache.lucene.util.BitSet;
 import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.FixedBitSet;
+import org.apache.lucene.util.RamUsageEstimator;
 import org.apache.lucene.util.VectorUtil;
 import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
 import org.junit.Before;
@@ -71,15 +75,8 @@ public class TestHnswGraph extends LuceneTestCase {
 
   @Before
   public void setup() {
-    similarityFunction =
-        VectorSimilarityFunction.values()[
-            random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
-    if (similarityFunction == VectorSimilarityFunction.DOT_PRODUCT) {
-      vectorEncoding =
-          VectorEncoding.values()[random().nextInt(VectorEncoding.values().length - 1) + 1];
-    } else {
-      vectorEncoding = VectorEncoding.FLOAT32;
-    }
+    similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values());
+    vectorEncoding = RandomizedTest.randomFrom(VectorEncoding.values());
   }
 
   // test writing out and reading in a graph gives the expected graph
@@ -158,8 +155,7 @@ public class TestHnswGraph extends LuceneTestCase {
     int M = random().nextInt(10) + 5;
     int beamWidth = random().nextInt(10) + 5;
     VectorSimilarityFunction similarityFunction =
-        VectorSimilarityFunction.values()[
-            random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
+        RandomizedTest.randomFrom(VectorSimilarityFunction.values());
     long seed = random().nextLong();
     HnswGraphBuilder.randSeed = seed;
     IndexWriterConfig iwc =
@@ -475,6 +471,27 @@ public class TestHnswGraph extends LuceneTestCase {
                 0));
   }
 
+  public void testRamUsageEstimate() throws IOException {
+    int size = atLeast(2000);
+    int dim = randomIntBetween(100, 1024);
+    int M = randomIntBetween(4, 96);
+
+    VectorSimilarityFunction similarityFunction =
+        RandomizedTest.randomFrom(VectorSimilarityFunction.values());
+    VectorEncoding vectorEncoding = RandomizedTest.randomFrom(VectorEncoding.values());
+    TestHnswGraph.RandomVectorValues vectors =
+        new TestHnswGraph.RandomVectorValues(size, dim, vectorEncoding, random());
+
+    HnswGraphBuilder<?> builder =
+        HnswGraphBuilder.create(
+            vectors, vectorEncoding, similarityFunction, M, M * 2, random().nextLong());
+    OnHeapHnswGraph hnsw = builder.build(vectors.copy());
+    long estimated = RamUsageEstimator.sizeOfObject(hnsw);
+    long actual = ramUsed(hnsw);
+
+    assertEquals((double) actual, (double) estimated, (double) actual * 0.3);
+  }
+
   @SuppressWarnings("unchecked")
   public void testDiversity() throws IOException {
     vectorEncoding = randomVectorEncoding();